In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader 
import os
import numpy as np
from tqdm import tqdm
from PIL import Image

In [2]:
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import math

class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.
    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.
    Args:
        num_classes (int): number of classes.
        epsilon (float): weight.
    """
    def __init__(self, num_classes, epsilon=0.1, device='cpu'):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.device = device
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (num_classes)
        """
        log_probs = self.logsoftmax(inputs)
        # targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data, 1)# for mldg da
        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)#for zzd
        targets = targets.to(self.device)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-Variable(targets) * log_probs).mean(0).sum()
        return loss

class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.
    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
    Args:
        margin (float): margin for triplet.
    """
    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (num_classes)
        """
        n = inputs.size(0)
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss

class CenterLoss(nn.Module):
    """Center loss.
    
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes=10, feat_dim=2048, device='cpu'):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.device = device

        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)).to(self.device)

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (num_classes).
        """
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        classes = classes.to(self.device)
        
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.data.eq(classes.expand(batch_size, self.num_classes))
        dist = []
        for i in range(batch_size):
            value = distmat[i][mask[i]]
            value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
            dist.append(value)
        dist = torch.cat(dist)
        loss = dist.mean()

        return loss

In [3]:
class FoodDataset(Dataset):
    def __init__(self, file, transform=None, mode='train'):
        self.transforms = transform
        self.mode = mode
        with open(file, 'r') as f:
            self.image_list = f.readlines()

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        label = None
        if self.mode == 'train':
            image, label = self.image_list[index].split('\n')[0].split('\t')
            label = int(label)
        else:
            image = self.image_list[index].split('\n')[0]
        image = Image.open(image).convert('RGB')
        image = self.transforms(image)
        if self.mode == 'train':
            return image, label
        else:
            return image

In [4]:
transforms_train = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.Pad(10, 10),
                transforms.RandomRotation(45),
                transforms.RandomCrop((224, 224)),
                transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

transforms_test = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

In [5]:
def evaluate(prediction, ground_truth):
    num_correct = (np.array(prediction) == np.array(ground_truth)).sum()
    return num_correct / len(prediction)

In [6]:
train_ds = FoodDataset('/media/ntu/volume2/home/s121md302_07/food/data/train.txt', transform=transforms_train)
val_ds = FoodDataset('/media/ntu/volume2/home/s121md302_07/food/data/val.txt', transform=transforms_test)
test_ds = FoodDataset('/media/ntu/volume2/home/s121md302_07/food/data/test.txt', transform=transforms_test)

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=8)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=True, num_workers=8)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=True, num_workers=8)

In [7]:
num_classes = 5
train_model = models.vgg19(pretrained=True)
# train_model.fc = nn.Linear(512, num_classes)
train_model.classifier[6] = nn.Linear(4096, num_classes)

model_str = 'mobilenetv2'
output_dir = 'checkpoint_' + model_str
if output_dir and not os.path.exists(output_dir):
    os.makedirs(output_dir)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ce_loss = CrossEntropyLabelSmooth(num_classes = num_classes, device = device)
optimizer = torch.optim.Adam(train_model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [8]:
for param in train_model.parameters():
    param.requires_grad = False
for param in train_model.classifier.parameters():
    param.requires_grad = True
for i in range(5):
    train_model.train()
    train_model.to(device)
    for img, label in tqdm(train_dl):
        img = img.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output= train_model(img) 
        loss = ce_loss(output, label) 
        loss.backward()
        optimizer.step()


100%|██████████| 16/16 [00:01<00:00,  8.06it/s]
100%|██████████| 16/16 [00:01<00:00,  8.42it/s]
100%|██████████| 16/16 [00:01<00:00,  8.43it/s]
100%|██████████| 16/16 [00:01<00:00,  8.58it/s]
100%|██████████| 16/16 [00:01<00:00,  8.56it/s]


In [9]:
for param in train_model.parameters():
    param.requires_grad = True
epoch = 100
highest_acc = {'epoch': 0, 'accuracy': 0}
for ep in range(epoch):
    train_model.train()
    train_model.to(device)
    count = 0
    running_loss = 0.0
    validation_loss = 0.0
    output_list = []
    ground_truth_list = []
    for img, label in tqdm(train_dl):
        img = img.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        output= train_model(img)
        loss = ce_loss(output, label)
        count += 1
        prediction = torch.argmax(output, dim=1)
        output_list.extend(prediction.detach().cpu())
        ground_truth_list.extend(label.cpu())
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    
    scheduler.step()
        
    # if ep % 10 == 0:
    #     torch.save(train_model.state_dict(), output_dir + '/' + model_str + '_' + str(ep) + '.pth')
        
    accuracy = evaluate(output_list, ground_truth_list)
    print(f'Epoch[{ep}] training accuracy: {accuracy} '
            f'training loss: {running_loss / count:.3e} Base Lr: {optimizer.param_groups[0]["lr"]:.5e}')

    if ep % 10 == 0:
        train_model.eval()
        count = 0
        output_list = []
        ground_truth_list = []
        for img, label in tqdm(val_dl):
            with torch.no_grad():
                img = img.to(device)
                lbl = label.to(device)

                output= train_model(img)

                val_loss = ce_loss(output, lbl)

                validation_loss += val_loss.item()
                count += 1
                prediction = torch.argmax(output, dim=1)
                output_list.extend(prediction.detach().cpu())
                ground_truth_list.extend(label)
        accuracy = evaluate(output_list, ground_truth_list)
        if accuracy > highest_acc['accuracy']:
            highest_acc['accuracy'] = accuracy
            highest_acc['epoch'] = ep
        print(f'Accuracy: {accuracy}    Epoch:{ep}')

100%|██████████| 16/16 [00:03<00:00,  5.08it/s]


Epoch[0] training accuracy: 0.434 training loss: 1.559e+00 Base Lr: 1.00000e-04


100%|██████████| 5/5 [00:01<00:00,  4.69it/s]


Accuracy: 0.5133333333333333    Epoch:0


100%|██████████| 16/16 [00:03<00:00,  5.06it/s]


Epoch[1] training accuracy: 0.584 training loss: 1.201e+00 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.08it/s]


Epoch[2] training accuracy: 0.676 training loss: 1.079e+00 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.09it/s]


Epoch[3] training accuracy: 0.746 training loss: 9.481e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.10it/s]


Epoch[4] training accuracy: 0.766 training loss: 8.390e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.07it/s]


Epoch[5] training accuracy: 0.82 training loss: 7.673e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.08it/s]


Epoch[6] training accuracy: 0.846 training loss: 7.265e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.04it/s]


Epoch[7] training accuracy: 0.878 training loss: 6.817e-01 Base Lr: 1.00000e-04


  0%|          | 0/16 [00:00<?, ?it/s]Exception ignored in: <function _releaseLock at 0x7fa748196f80>
Traceback (most recent call last):
  File "/media/ntu/volume2/home/s121md302_07/anaconda3/envs/pytorch/lib/python3.7/logging/__init__.py", line 221, in _releaseLock
    def _releaseLock():
KeyboardInterrupt
100%|██████████| 16/16 [00:03<00:00,  4.96it/s]


Epoch[8] training accuracy: 0.894 training loss: 6.435e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  4.91it/s]


Epoch[9] training accuracy: 0.914 training loss: 6.149e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.00it/s]


Epoch[10] training accuracy: 0.914 training loss: 6.076e-01 Base Lr: 1.00000e-04


100%|██████████| 5/5 [00:01<00:00,  4.61it/s]


Accuracy: 0.8133333333333334    Epoch:10


100%|██████████| 16/16 [00:03<00:00,  5.06it/s]


Epoch[11] training accuracy: 0.93 training loss: 5.781e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.07it/s]


Epoch[12] training accuracy: 0.962 training loss: 5.172e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.09it/s]


Epoch[13] training accuracy: 0.968 training loss: 4.994e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.04it/s]


Epoch[14] training accuracy: 0.93 training loss: 5.649e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.00it/s]


Epoch[15] training accuracy: 0.936 training loss: 5.346e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.06it/s]


Epoch[16] training accuracy: 0.946 training loss: 5.249e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.00it/s]


Epoch[17] training accuracy: 0.954 training loss: 5.141e-01 Base Lr: 1.00000e-04


  0%|          | 0/16 [00:00<?, ?it/s]Exception ignored in: <function _releaseLock at 0x7fa748196f80>
Traceback (most recent call last):
  File "/media/ntu/volume2/home/s121md302_07/anaconda3/envs/pytorch/lib/python3.7/logging/__init__.py", line 221, in _releaseLock
    def _releaseLock():
KeyboardInterrupt
100%|██████████| 16/16 [00:03<00:00,  5.03it/s]


Epoch[18] training accuracy: 0.962 training loss: 4.751e-01 Base Lr: 1.00000e-04


100%|██████████| 16/16 [00:03<00:00,  5.00it/s]


Epoch[19] training accuracy: 0.982 training loss: 4.687e-01 Base Lr: 1.00000e-05


100%|██████████| 16/16 [00:03<00:00,  5.09it/s]


Epoch[20] training accuracy: 0.984 training loss: 4.597e-01 Base Lr: 1.00000e-05


100%|██████████| 5/5 [00:01<00:00,  4.55it/s]


Accuracy: 0.8    Epoch:20


100%|██████████| 16/16 [00:03<00:00,  5.02it/s]


Epoch[21] training accuracy: 0.988 training loss: 4.376e-01 Base Lr: 1.00000e-05


100%|██████████| 16/16 [00:03<00:00,  5.04it/s]


Epoch[22] training accuracy: 0.996 training loss: 4.304e-01 Base Lr: 1.00000e-05


 12%|█▎        | 2/16 [00:00<00:06,  2.30it/s]

In [None]:
# torch.save(train_model.state_dict(), output_dir + '/' + model_str + '_' + 'final' + '.pth')
print('highest_acc: {}  epoch: {}'.format(highest_acc['accuracy'], highest_acc['epoch']))