In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
torch.set_printoptions(linewidth=120)
import skimage.io as io
import cv2 as cv2
import os
from torch.utils.tensorboard import SummaryWriter
print(torch.__version__)
print(torchvision.__version__)

1.4.0
0.5.0


In [2]:
def prepare_data(batch_size,path):
    data_dir = path
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)

    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(100),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(100),
            transforms.CenterCrop(100),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
         'test': transforms.Compose([
            transforms.Resize(100),
            transforms.CenterCrop(100),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }


    folders=os.listdir(os.path.join(data_dir, 'train'))
    data_set = {x: torchvision.datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                  for x in ['train', 'val','test']}
    train_loader = {x: torch.utils.data.DataLoader(data_set[x], batch_size=batch_size,
                                                 shuffle=True, num_workers=4)
                  for x in ['train', 'val','test']}
    return folders,data_set,train_loader,device


In [3]:
def get_num_correct(preds,labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [4]:
#class focal loss to calculate the loss depending on the hard examples
#focal loss class is for multi-label classification
from torch.autograd import Variable
def one_hot(index, classes,device):
    size = index.size() + (classes,)
    view = index.size() + (1,)

    mask = torch.Tensor(*size).fill_(0)
    mask=mask.to(device)
    index = index.view(*view)
    index=index.to(device)
    ones = 1.

    if isinstance(index, Variable):
        ones = Variable(torch.Tensor(index.size()).fill_(1))
        ones=ones.to(device)
        mask = Variable(mask, volatile=index.volatile)

    return mask.scatter_(1, index, ones)
class FocalLoss(nn.Module):

    def __init__(self, gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target,device):
        y = one_hot(target, input.size(-1),device)
        logit = F.softmax(input, dim=-1)
        logit = logit.clamp(self.eps, 1. - self.eps)

        loss = -1 * y * torch.log(logit) # cross entropy
        loss = loss * (1 - logit) ** self.gamma # focal loss

        return loss.sum()


In [5]:
def get_resnet(n,device):
    resnet50 = models.resnet50(pretrained=True)
    for param in resnet50.parameters():
        param.require_grad = False

    num_ftrs = resnet50.fc.in_features
    resnet50.fc = nn.Linear(num_ftrs, n)
    resnet50 = resnet50.to(device)
    return resnet50

In [6]:
def model(resnet50,lr,n,device,train_loader,batch_size,epoch_num,tb):
    optimizer= optim.Adam(resnet50.parameters(),lr=lr)
    focal_loss_multilabel=FocalLoss()
    for epoch in range (epoch_num):
        total_loss=0
        total_correct=0
        #images, labels represent current batch
        for images,labels in train_loader['train']:
            images = images.to(device)
            labels = labels.to(device)
#             grid=torchvision.utils.make_grid(images)
#             tb.add_images('images',grid, dataformats='CHW')
#             tb.add_graph(resnet50,images)
#             tb.close()
            preds=resnet50(images)
            loss=None
            if(n>=3):
                #if you want to calculate it using the normal cross entropy
#                 loss = F.cross_entropy(preds,labels)
                #if you want to calculate it using the focal loss
                loss=focal_loss_multilabel.forward(preds,labels,device)
            else:
                #if you want to calculate it using the normal cross entropy
#                 loss = F.binary_cross_entropy(preds,labels)
                #if you want to calculate it using the focal loss
                BCE_loss = F.binary_cross_entropy_with_logits(preds, labels, reduction='none')
                pt = torch.exp(-BCE_loss) # prevents nans when probability 0
                Focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
                loss= Focal_loss.mean()
                
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss+=loss.item()
            total_correct+=get_num_correct(preds,labels)
            print ("epoch: ",epoch," total_correct: ",total_correct," total_loss: ",total_loss)
        tb.add_scalar('loss',total_loss,epoch)
        tb.add_scalar('number correct',total_correct,epoch)
        tb.add_scalar('accuracy',total_correct/(len(train_loader['train'])*batch_size),epoch)
    return total_correct,total_loss

In [7]:
def print_train_accuracy(total_correct,train_loader,batch_size):
    print("train accuracy: ",total_correct/(len(train_loader['train'])*batch_size))

In [8]:
@torch.no_grad()
def validation(resnet50,device,train_loader,n,tb,batch_size):
    total_loss=0
    total_correct=0
    focal_loss_multilabel=FocalLoss()
    for images,labels in train_loader['val']:
        images = images.to(device)
        labels = labels.to(device)
        preds=resnet50(images)
        loss=None
        
        if(n>=3):
            #if you want to calculate it using the normal cross entropy
#                 loss = F.cross_entropy(preds,labels)
            #if you want to calculate it using the focal loss
            loss=focal_loss_multilabel.forward(preds,labels,device)
        else:
            #if you want to calculate it using the normal cross entropy
#                 loss = F.binary_cross_entropy(preds,labels)
            #if you want to calculate it using the focal loss
            BCE_loss = F.binary_cross_entropy_with_logits(preds, labels, reduction='none')
            pt = torch.exp(-BCE_loss) # prevents nans when probability 0
            Focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
            loss= Focal_loss.mean()
            
        total_loss+=loss.item()
        total_correct+=get_num_correct(preds,labels)
        print ("total_correct: ",total_correct," total_loss: ",total_loss)
    tb.add_scalar('loss',total_loss,batch_size)
    tb.add_scalar('number correct',total_correct,batch_size)
    tb.add_scalar('accuracy',total_correct/(len(train_loader['train'])*batch_size),batch_size)
    return total_correct,total_loss

In [9]:
def print_validation_accuracy(total_correct,train_loader,batch_size):
    print("validation accuracy: ",total_correct/(len(train_loader['val'])*batch_size))

In [10]:
@torch.no_grad()
def testing(resnet50,device,train_loader,n):
    total_loss=0
    total_correct=0
    focal_loss_multilabel=FocalLoss()
    for images,labels in train_loader['test']:
        images = images.to(device)
        labels = labels.to(device)
        preds=resnet50(images)
        loss=None
        if(n>=3):
            #if you want to calculate it using the normal cross entropy
#                 loss = F.cross_entropy(preds,labels)
            #if you want to calculate it using the focal loss
            loss=focal_loss_multilabel.forward(preds,labels,device)
        else:
            #if you want to calculate it using the normal cross entropy
#                 loss = F.binary_cross_entropy(preds,labels)
            #if you want to calculate it using the focal loss
            BCE_loss = F.binary_cross_entropy_with_logits(preds, labels, reduction='none')
            pt = torch.exp(-BCE_loss) # prevents nans when probability 0
            Focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
            loss= Focal_loss.mean()
        total_loss+=loss.item()
        total_correct+=get_num_correct(preds,labels)
        print ("total_correct: ",total_correct," total_loss: ",total_loss)
    return total_correct,total_loss

In [11]:
def print_testing_accuracy(total_correct,train_loader,batch_size):
    print("testing_accuracy: ",total_correct/(len(train_loader['test'])*batch_size))

In [12]:
def main():
#     batch_size_list=[32,64,128]
#     lr_list=[0.0001,0.00001,0.000001]
#     epoch_num_list=[10,17,23]
#     for batch_size in batch_size_list :
#         for lr in lr_list :
#             for epoch_num in epoch_num_list:
                batch_size=64
                lr=0.00001
                epoch_num=23
                folder,data_set,train_loader,device=prepare_data(batch_size,"/home/ahmed/intern work/image classification/animal-image-datasetdog-cat-and-panda")
                print(len(folder))
                print("no of data = ",len(train_loader['train'])*batch_size)
                n=len(folder)
                tb=SummaryWriter()
                resnet50=get_resnet(n,device)
                total_correct,total_loss=model(resnet50,lr,n,device,train_loader,batch_size,epoch_num,tb)
                print_train_accuracy(total_correct,train_loader,batch_size)
                total_correct,total_loss=validation(resnet50,device,train_loader,n,tb,batch_size)
                print_validation_accuracy(total_correct,train_loader,batch_size)
                total_correct,total_loss=testing(resnet50,device,train_loader,n)
                print_testing_accuracy(total_correct,train_loader,batch_size)

In [13]:
main()

cuda:0
3
no of data =  2176




epoch:  0  total_correct:  30  total_loss:  68.32234191894531
epoch:  0  total_correct:  53  total_loss:  138.0527114868164
epoch:  0  total_correct:  83  total_loss:  202.70592498779297
epoch:  0  total_correct:  119  total_loss:  265.53893661499023
epoch:  0  total_correct:  157  total_loss:  326.8213768005371
epoch:  0  total_correct:  191  total_loss:  390.432315826416
epoch:  0  total_correct:  221  total_loss:  453.9290657043457
epoch:  0  total_correct:  259  total_loss:  515.063892364502
epoch:  0  total_correct:  307  total_loss:  570.3730430603027
epoch:  0  total_correct:  354  total_loss:  626.4813804626465
epoch:  0  total_correct:  396  total_loss:  684.331356048584
epoch:  0  total_correct:  443  total_loss:  736.4275093078613
epoch:  0  total_correct:  492  total_loss:  788.6994590759277
epoch:  0  total_correct:  539  total_loss:  841.132137298584
epoch:  0  total_correct:  585  total_loss:  892.5668067932129
epoch:  0  total_correct:  631  total_loss:  944.43903732299

epoch:  3  total_correct:  1669  total_loss:  400.10960388183594
epoch:  3  total_correct:  1728  total_loss:  413.2683849334717
epoch:  3  total_correct:  1785  total_loss:  429.05384826660156
epoch:  3  total_correct:  1841  total_loss:  445.2961730957031
epoch:  3  total_correct:  1898  total_loss:  461.6506690979004
epoch:  3  total_correct:  1955  total_loss:  479.9057807922363
epoch:  3  total_correct:  1970  total_loss:  490.02149295806885
epoch:  4  total_correct:  62  total_loss:  10.414356231689453
epoch:  4  total_correct:  124  total_loss:  22.171344757080078
epoch:  4  total_correct:  185  total_loss:  34.287240982055664
epoch:  4  total_correct:  247  total_loss:  46.270562171936035
epoch:  4  total_correct:  307  total_loss:  57.27195644378662
epoch:  4  total_correct:  365  total_loss:  73.87461757659912
epoch:  4  total_correct:  424  total_loss:  86.50844287872314
epoch:  4  total_correct:  488  total_loss:  92.90542125701904
epoch:  4  total_correct:  550  total_loss

epoch:  7  total_correct:  1300  total_loss:  134.9877188205719
epoch:  7  total_correct:  1362  total_loss:  145.59894013404846
epoch:  7  total_correct:  1426  total_loss:  150.85679602622986
epoch:  7  total_correct:  1487  total_loss:  158.02488112449646
epoch:  7  total_correct:  1544  total_loss:  171.81602835655212
epoch:  7  total_correct:  1605  total_loss:  177.9944303035736
epoch:  7  total_correct:  1666  total_loss:  183.2606647014618
epoch:  7  total_correct:  1727  total_loss:  192.5253827571869
epoch:  7  total_correct:  1788  total_loss:  200.0685431957245
epoch:  7  total_correct:  1851  total_loss:  205.11783957481384
epoch:  7  total_correct:  1909  total_loss:  215.60123324394226
epoch:  7  total_correct:  1970  total_loss:  223.12147784233093
epoch:  7  total_correct:  2034  total_loss:  226.4083812236786
epoch:  7  total_correct:  2050  total_loss:  229.77667140960693
epoch:  8  total_correct:  63  total_loss:  4.183136940002441
epoch:  8  total_correct:  127  to

epoch:  11  total_correct:  817  total_loss:  59.671372413635254
epoch:  11  total_correct:  879  total_loss:  65.76745510101318
epoch:  11  total_correct:  943  total_loss:  68.41450214385986
epoch:  11  total_correct:  1007  total_loss:  72.47200965881348
epoch:  11  total_correct:  1069  total_loss:  77.33892393112183
epoch:  11  total_correct:  1131  total_loss:  82.83236074447632
epoch:  11  total_correct:  1195  total_loss:  86.17287826538086
epoch:  11  total_correct:  1257  total_loss:  90.54490566253662
epoch:  11  total_correct:  1318  total_loss:  96.32688856124878
epoch:  11  total_correct:  1382  total_loss:  98.78924798965454
epoch:  11  total_correct:  1443  total_loss:  107.78089094161987
epoch:  11  total_correct:  1506  total_loss:  112.52269506454468
epoch:  11  total_correct:  1569  total_loss:  115.82120013237
epoch:  11  total_correct:  1630  total_loss:  122.62456011772156
epoch:  11  total_correct:  1690  total_loss:  129.52979350090027
epoch:  11  total_correct

epoch:  15  total_correct:  250  total_loss:  13.618239164352417
epoch:  15  total_correct:  312  total_loss:  18.35962986946106
epoch:  15  total_correct:  374  total_loss:  22.515104055404663
epoch:  15  total_correct:  437  total_loss:  26.29609227180481
epoch:  15  total_correct:  500  total_loss:  31.803646326065063
epoch:  15  total_correct:  564  total_loss:  33.75981783866882
epoch:  15  total_correct:  628  total_loss:  35.908292055130005
epoch:  15  total_correct:  690  total_loss:  41.57625603675842
epoch:  15  total_correct:  753  total_loss:  45.70236849784851
epoch:  15  total_correct:  816  total_loss:  48.055638551712036
epoch:  15  total_correct:  880  total_loss:  50.840845823287964
epoch:  15  total_correct:  942  total_loss:  53.98202919960022
epoch:  15  total_correct:  1006  total_loss:  56.815338373184204
epoch:  15  total_correct:  1070  total_loss:  58.66062569618225
epoch:  15  total_correct:  1130  total_loss:  65.31295990943909
epoch:  15  total_correct:  11

epoch:  18  total_correct:  1814  total_loss:  130.7335329055786
epoch:  18  total_correct:  1876  total_loss:  138.02287578582764
epoch:  18  total_correct:  1939  total_loss:  141.4637634754181
epoch:  18  total_correct:  2002  total_loss:  147.02098441123962
epoch:  18  total_correct:  2066  total_loss:  148.70085334777832
epoch:  18  total_correct:  2084  total_loss:  150.34628438949585
epoch:  19  total_correct:  62  total_loss:  5.729355812072754
epoch:  19  total_correct:  126  total_loss:  8.624703884124756
epoch:  19  total_correct:  188  total_loss:  13.53842306137085
epoch:  19  total_correct:  249  total_loss:  20.395144939422607
epoch:  19  total_correct:  312  total_loss:  24.776050090789795
epoch:  19  total_correct:  376  total_loss:  26.314030170440674
epoch:  19  total_correct:  440  total_loss:  27.94584321975708
epoch:  19  total_correct:  503  total_loss:  31.652157306671143
epoch:  19  total_correct:  567  total_loss:  33.89088416099548
epoch:  19  total_correct: 

epoch:  22  total_correct:  1259  total_loss:  59.18312609195709
epoch:  22  total_correct:  1322  total_loss:  60.77596986293793
epoch:  22  total_correct:  1383  total_loss:  66.71681678295135
epoch:  22  total_correct:  1446  total_loss:  70.94444072246552
epoch:  22  total_correct:  1509  total_loss:  75.96457612514496
epoch:  22  total_correct:  1572  total_loss:  79.06479489803314
epoch:  22  total_correct:  1632  total_loss:  86.83328187465668
epoch:  22  total_correct:  1694  total_loss:  96.16174447536469
epoch:  22  total_correct:  1757  total_loss:  97.85346758365631
epoch:  22  total_correct:  1820  total_loss:  100.05451786518097
epoch:  22  total_correct:  1883  total_loss:  105.79401648044586
epoch:  22  total_correct:  1947  total_loss:  107.6103367805481
epoch:  22  total_correct:  2010  total_loss:  111.19250988960266
epoch:  22  total_correct:  2073  total_loss:  115.01529955863953
epoch:  22  total_correct:  2092  total_loss:  115.29491311311722
train accuracy:  0.9

NameError: name 'focal_loss_multilabel' is not defined