In [1]:
%matplotlib inline
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
import numpy as np
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision import transforms
import torch.optim as optim
import time
import tqdm as tqdm
from torch.autograd import Variable

batch_size = 256
transform = transforms.Compose([transforms.ToTensor(),
                                # expand chennel from 1 to 3 to fit 
                                # ResNet pretrained model
                                transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                                ]) 

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
    shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
    shuffle=True, num_workers=0)

for x, y in train_loader:
    print(x.shape, y.shape)
    break

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


torch.Size([256, 3, 28, 28]) torch.Size([256])




In [2]:
# print(models.resnet18())
class ResNetFeatrueExtractor18(nn.Module):
    def __init__(self, pretrained = True):
        super(ResNetFeatrueExtractor18, self).__init__()
        model_resnet18 = models.resnet18(pretrained=pretrained)
        self.conv1 = model_resnet18.conv1
        self.bn1 = model_resnet18.bn1
        self.relu = model_resnet18.relu
        self.maxpool = model_resnet18.maxpool
        self.layer1 = model_resnet18.layer1
        self.layer2 = model_resnet18.layer2
        self.layer3 = model_resnet18.layer3
        self.layer4 = model_resnet18.layer4
        self.avgpool = model_resnet18.avgpool

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x

class ResClassifier(nn.Module):
    def __init__(self, dropout_p=0.5): #in_features=512
        super(ResClassifier, self).__init__()        
        self.fc = nn.Linear(512, 10)
    def forward(self, x):       
        out = self.fc(x)
        return out

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)


def evaluate_accuracy(data_iter, netG, netF):
    """Evaluate accuracy of a model on the given data set."""
    acc_sum,n = 0,0
    for (imgs, labels) in data_iter:
        # send data to the GPU if cuda is availabel
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
        netG.eval()
        netF.eval()
        with torch.no_grad():
            labels = labels.long()
            acc_sum += torch.sum((torch.argmax(netF(netG(imgs)), dim=1) == labels)).float()
            n += labels.shape[0]
    return acc_sum.item()/n

## **Training using Pre-trained model**

In [8]:
netG = ResNetFeatrueExtractor18()
netF = ResClassifier()

if torch.cuda.is_available():
    netG = netG.cuda()
    netF = netF.cuda()

# setting up optimizer for both feature generator G and classifier F.
opt_g = optim.SGD(netG.parameters(), lr=0.01, weight_decay=0.0005)
opt_f = optim.SGD(netF.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

# loss function
criterion = nn.CrossEntropyLoss()

for epoch in range(0, 10):
    n, start = 0, time.time()
    train_l_sum = torch.tensor([0.0], dtype=torch.float32)
    train_acc_sum = torch.tensor([0.0], dtype=torch.float32)
    for i, (imgs, labels) in tqdm.tqdm(enumerate(iter(train_loader))):
        netG.train()
        netF.train()
        imgs = Variable(imgs)
        labels = Variable(labels)     
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
            train_l_sum = train_l_sum.cuda()
            train_acc_sum = train_acc_sum.cuda()

        opt_g.zero_grad()
        opt_f.zero_grad()

        # extracted feature
        bottleneck = netG(imgs)     
        
        # predicted labels
        label_hat = netF(bottleneck)

        # loss function
        loss= criterion(label_hat, labels)
        loss.backward()
        opt_g.step()
        opt_f.step()
        
        # calcualte training error
        netG.eval()
        netF.eval()
        labels = labels.long()
        train_l_sum += loss.float()
        train_acc_sum += (torch.sum((torch.argmax(label_hat, dim=1) == labels))).float()
        n += labels.shape[0]
    test_acc = evaluate_accuracy(iter(test_loader), netG, netF) 
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\
        % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc, time.time() - start))


235it [00:19, 12.03it/s]
2it [00:00, 17.54it/s]

epoch 1, loss 0.0008, train acc 0.936, test acc 0.983, time 20.6 sec


235it [00:20, 11.44it/s]
2it [00:00, 17.05it/s]

epoch 2, loss 0.0002, train acc 0.988, test acc 0.986, time 21.6 sec


235it [00:20, 11.46it/s]
2it [00:00, 17.76it/s]

epoch 3, loss 0.0001, train acc 0.993, test acc 0.988, time 21.6 sec


235it [00:19, 11.85it/s]
2it [00:00, 18.02it/s]

epoch 4, loss 0.0001, train acc 0.995, test acc 0.988, time 20.9 sec


235it [00:19, 11.77it/s]
2it [00:00, 16.58it/s]

epoch 5, loss 0.0000, train acc 0.997, test acc 0.989, time 21.1 sec


235it [00:20, 11.57it/s]
2it [00:00, 17.43it/s]

epoch 6, loss 0.0000, train acc 0.998, test acc 0.989, time 21.4 sec


235it [00:20, 11.66it/s]
2it [00:00, 16.08it/s]

epoch 7, loss 0.0000, train acc 0.999, test acc 0.989, time 21.2 sec


235it [00:20, 11.73it/s]
2it [00:00, 17.21it/s]

epoch 8, loss 0.0000, train acc 0.999, test acc 0.989, time 21.1 sec


235it [00:20, 11.65it/s]
2it [00:00, 17.52it/s]

epoch 9, loss 0.0000, train acc 0.999, test acc 0.990, time 21.3 sec


235it [00:20, 11.66it/s]


epoch 10, loss 0.0000, train acc 1.000, test acc 0.989, time 21.2 sec


## **Training without Pre-trained model**

In [3]:
# setting pretrained to False. The rest is the same
netG = ResNetFeatrueExtractor18(pretrained=False)
netF = ResClassifier()

if torch.cuda.is_available():
    netG = netG.cuda()
    netF = netF.cuda()

opt_g = optim.SGD(netG.parameters(), lr=0.01, weight_decay=0.0005)
opt_f = optim.SGD(netF.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

criterion = nn.CrossEntropyLoss()

for epoch in range(0, 10):
    n, start = 0, time.time()
    train_l_sum = torch.tensor([0.0], dtype=torch.float32)
    train_acc_sum = torch.tensor([0.0], dtype=torch.float32)
    for i, (imgs, labels) in tqdm.tqdm(enumerate(iter(train_loader))):
        netG.train()
        netF.train()
        imgs = Variable(imgs)
        labels = Variable(labels)     
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
            train_l_sum = train_l_sum.cuda()
            train_acc_sum = train_acc_sum.cuda()

        opt_g.zero_grad()
        opt_f.zero_grad()

        bottleneck = netG(imgs)     
        
        label_hat = netF(bottleneck)

        # loss function
        loss= criterion(label_hat, labels)
        loss.backward()
        opt_g.step()
        opt_f.step()
        
        # calcualte training error
        netG.eval()
        netF.eval()
        labels = labels.long()
        train_l_sum += loss.float()
        train_acc_sum += (torch.sum((torch.argmax(label_hat, dim=1) == labels))).float()
        n += labels.shape[0]
    test_acc = evaluate_accuracy(iter(test_loader), netG, netF) 
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\
        % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc, time.time() - start))


235it [00:18, 12.51it/s]
2it [00:00, 19.50it/s]

epoch 1, loss 0.0009, train acc 0.933, test acc 0.974, time 19.8 sec


235it [00:19, 12.32it/s]
2it [00:00, 19.02it/s]

epoch 2, loss 0.0002, train acc 0.984, test acc 0.981, time 20.1 sec


235it [00:19, 12.03it/s]
2it [00:00, 18.88it/s]

epoch 3, loss 0.0001, train acc 0.993, test acc 0.983, time 20.4 sec


235it [00:19, 12.27it/s]
2it [00:00, 18.43it/s]

epoch 4, loss 0.0000, train acc 0.998, test acc 0.984, time 20.1 sec


235it [00:19, 12.36it/s]
2it [00:00, 19.07it/s]

epoch 5, loss 0.0000, train acc 0.999, test acc 0.985, time 19.9 sec


235it [00:19, 12.36it/s]
2it [00:00, 19.98it/s]

epoch 6, loss 0.0000, train acc 1.000, test acc 0.985, time 19.9 sec


235it [00:19, 12.23it/s]
2it [00:00, 19.36it/s]

epoch 7, loss 0.0000, train acc 1.000, test acc 0.986, time 20.2 sec


235it [00:19, 12.27it/s]
2it [00:00, 19.37it/s]

epoch 8, loss 0.0000, train acc 1.000, test acc 0.986, time 20.0 sec


235it [00:19, 12.35it/s]
2it [00:00, 17.92it/s]

epoch 9, loss 0.0000, train acc 1.000, test acc 0.986, time 20.1 sec


235it [00:19, 12.35it/s]


epoch 10, loss 0.0000, train acc 1.000, test acc 0.987, time 20.1 sec


The model without pre-train is slightly wores than the model with pre-train. But as the MNIST dataset is pretty simple, the difference is not significant.