In [61]:
import torch as tc
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split


import torchvision as tv
from torchvision import transforms as T
from torchvision import datasets
from torchvision import models

from torchmetrics import Accuracy

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import copy

In [62]:
device = 'cuda' if tc.cuda.is_available() else 'cpu'
device

'cuda'

In [63]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [64]:
def train_one_epoch(model, train_loader, loss_func, optimizer,shedular, epoch):
    model.train()

    train_loss=AverageMeter()
    train_acc=Accuracy(task='multiclass', num_classes=10).to(device)

    with tqdm(train_loader, unit='batch') as tepoch:
        for x, y in tepoch:
            if epoch is not None:
                tepoch.set_description(f'Epoch={epoch}')

            optimizer.zero_grad()

            x = x.to(device)
            y = y.to(device)

            yp = model(x)

            loss = loss_func(yp.squeeze(), y)
            loss.backward()
            optimizer.step()
            shedular.step()


            train_loss.update(loss.item())
            train_acc(yp.squeeze(), y.int())

            tepoch.set_postfix(loss=train_loss.avg, Accuracy=train_acc.compute().item() * 100)

    return model, train_loss.avg, train_acc.compute().item() * 100

In [65]:
def validation(model, valid_loaedr, loss_func):
    model.eval()

    valid_loss=AverageMeter()
    valid_acc=Accuracy(task='multiclass', num_classes=10).to(device)

    for x, y in valid_loaedr:

        x=x.to(device)
        y=y.to(device)

        yp=model(x)
        loss=loss_func(yp.squeeze(), y)

        valid_loss.update(loss.item())
        valid_acc(yp.squeeze(), y)

    print(f'valid loss={valid_loss.avg:.4}, accuracy={valid_acc.compute().item() * 100 :.4}')
    print()

    return valid_loss.avg, valid_acc.compute().item() * 100

In [66]:
train_transform = T.Compose([T.RandomCrop(32, 4),
                             T.RandomHorizontalFlip(),
                             T.ToTensor(),
                             T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

test_transform = T.Compose([T.ToTensor(),
                            T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

In [67]:
train_dataset = datasets.CIFAR10(root =  '/home/ahmadrezabaqerzadeh/datasets/', train = True, download = True, transform = train_transform)
test_dataset  = datasets.CIFAR10(root = '/home/ahmadrezabaqerzadeh/datasets/', train = False, download = True, transform = test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [68]:
train_loader = DataLoader(train_dataset, batch_size = 512, shuffle = True, num_workers = 2)
test_loader  = DataLoader(test_dataset, batch_size = 100, shuffle = False, num_workers = 2)

In [71]:
class EnsembleModels(nn.Module):

    def __init__(self):
        super(EnsembleModels, self).__init__()
        with tc.no_grad():
            self.model0 = models.mobilenet_v2()
            self.model0.classifier[1] = nn.Linear(1280, 10)
            self.model0.load_state_dict(tc.load('/content/TeacherModel.pt', map_location = device))
            self.model0 = self.model0.requires_grad_(False)

            self.model1 = models.mobilenet_v2()
            self.model1.classifier[1] = nn.Linear(1280, 10)
            self.model1.load_state_dict(tc.load('/content/SDModel.pt', map_location = device))
            self.model1 = self.model1.requires_grad_(False)

            self.model2 = models.mobilenet_v2()
            self.model2.classifier[1] = nn.Linear(1280, 10)
            self.model2.load_state_dict(tc.load('/content/SDModel.pt', map_location = device))
            self.model2 = self.model2.requires_grad_(False)



        self.fc0 = nn.Linear(10, 10)

    def forward(self, x):

        y0 = self.model0(x)
        y1 = self.model1(x)
        y2 = self.model2(x)

        comb = tc.stack([y0, y1, y2], dim = 0)
        out  = tc.mean(comb, dim = 0)
        out  = out.squeeze()

        out  = self.fc0(out)

        return out

In [72]:
x, y = next(iter(train_loader))

In [73]:
model = EnsembleModels().to(device)

In [74]:
def num_parameters(model):
    n = sum([tc.numel(p) for p in model.parameters() if p.requires_grad])
    return str(n)

In [75]:
num_parameters(model)

'110'

In [76]:
model(x.to(device)).shape

torch.Size([512, 10])

In [77]:
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9, weight_decay = 5e-4)
shedular  = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 3000, eta_min = 1e-8)

In [78]:
loss_func = nn.CrossEntropyLoss()

In [79]:
loss_train_hist, acc_train_hist = [], []
loss_valid_hist, acc_valid_hist = [], []

best_acc   = 0
best_model = [0]

In [80]:
start, end = 0, 10
for epoch in range(start, end):

    model, train_loss, train_acc = train_one_epoch(model, train_loader, loss_func, optimizer, shedular, epoch)
    val_loss, val_acc            = validation(model, test_loader, loss_func)

    loss_train_hist.append(train_loss)
    acc_train_hist.append(train_acc)

    loss_valid_hist.append(val_loss)
    acc_valid_hist.append(val_acc)

    if val_acc>best_acc:
        best_model = model
        best_acc  = val_acc
        print('model saved!')

Epoch=0: 100%|██████████| 98/98 [00:23<00:00,  4.22batch/s, Accuracy=89, loss=0.451]


valid loss=0.3539, accuracy=88.94

model saved!


Epoch=1: 100%|██████████| 98/98 [00:23<00:00,  4.22batch/s, Accuracy=99, loss=0.0503]


valid loss=0.3653, accuracy=88.95

model saved!


Epoch=2: 100%|██████████| 98/98 [00:22<00:00,  4.35batch/s, Accuracy=99.1, loss=0.0417]


valid loss=0.3753, accuracy=89.02

model saved!


Epoch=3: 100%|██████████| 98/98 [00:21<00:00,  4.47batch/s, Accuracy=99.1, loss=0.0405]


valid loss=0.3835, accuracy=89.02



Epoch=4: 100%|██████████| 98/98 [00:21<00:00,  4.49batch/s, Accuracy=99.1, loss=0.0377]


valid loss=0.3914, accuracy=88.97



Epoch=5: 100%|██████████| 98/98 [00:22<00:00,  4.40batch/s, Accuracy=99.1, loss=0.0367]


valid loss=0.3961, accuracy=89.19

model saved!


Epoch=6: 100%|██████████| 98/98 [00:22<00:00,  4.27batch/s, Accuracy=99.2, loss=0.034]


valid loss=0.4016, accuracy=89.12



Epoch=7: 100%|██████████| 98/98 [00:22<00:00,  4.28batch/s, Accuracy=99.1, loss=0.0334]


valid loss=0.407, accuracy=89.1



Epoch=8: 100%|██████████| 98/98 [00:23<00:00,  4.26batch/s, Accuracy=99.1, loss=0.0327]


valid loss=0.4095, accuracy=89.07



Epoch=9: 100%|██████████| 98/98 [00:22<00:00,  4.29batch/s, Accuracy=99.1, loss=0.0319]


valid loss=0.4139, accuracy=88.99

