In [1]:
# imports
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import resnet101
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
from iirc.datasets_loader import get_lifelong_datasets
from iirc.definitions import PYTORCH, IIRC_SETUP
from iirc.utils.download_cifar import download_extract_cifar100

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#### Pre-trained feature extractor (imagenet)

In [3]:
# gmm
from sklearn.mixture import GaussianMixture

gmm = GaussianMixture(n_components=2, covariance_type="full", random_state=0)

# dataset
class Dataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


In [3]:
# cifar100 dataset
from torchvision.datasets import CIFAR100
train_dataset = CIFAR100(root="data", train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))
test_dataset = CIFAR100(root="data", train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))

# dataloader
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=True, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# import torchvision.transforms as transforms

# essential_transforms_fn = transforms.ToTensor()
# augmentation_transforms_fn = transforms.Compose([
#     transforms.RandomCrop(32,padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor()
# ])

# dataset_splits, tasks, class_names_to_idx = \
#     get_lifelong_datasets(dataset_name = "iirc_cifar100",
#                           dataset_root = "./data", # the imagenet folder (where the train and val folders reside, or the parent directory of cifar-100-python folder
#                           setup = IIRC_SETUP,
#                           framework = PYTORCH,
#                           tasks_configuration_id = 0,
#                           essential_transforms_fn = essential_transforms_fn,
#                           augmentation_transforms_fn = augmentation_transforms_fn,
#                           joint = False
#                          )

# # print(len(tasks))
# n_classes_per_task = []
# for task in tasks:
#     n_classes_per_task.append(len(task))
# n_classes_per_task = np.array(n_classes_per_task)

# # lifelong_datasets['train'].choose_task(2)
# # print(list(zip(*lifelong_datasets['train']))[1])
# for i in dataset_splits:
#     print(i)


Creating iirc_cifar100
Setup used: IIRC
Using PyTorch
Dataset created
train
intask_valid
posttask_valid
test


In [6]:
from torchmetrics.classification import MulticlassRecall
from torchmetrics import ClasswiseWrapper

def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    # class_recall = ClasswiseWrapper(MulticlassRecall(num_classes=seen_classes, average="macro"), None)
    # class_recall = class_recall.to(device)
    recall = 0
        
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)  
            outputs = model(images)

            predicted = torch.argmax(outputs.data, 1)
            # print(predicted, labels)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # print(class_recall(predicted, labels))

    print("Accuracy of the model on the test images: {} %".format(100 * correct / total))
    # print("Recall of the model on the test images: {} %".format(class_recall))

    return (correct / total) * 100

In [50]:
def train(model, train_loader, test_loader, criterion, optimizer,scheduler,  epochs):
    loss_list = []
    acc_list = []
     
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, labels) in enumerate(train_loader):
            data = data.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            output = model(data)

            
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            # scheduler.step()

            if batch_idx % 10 == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )
        loss_list.append(loss.item())
        acc = test(model, test_loader)
        acc_list.append(acc)

    return loss_list, acc_list

In [68]:
# def set_parameter_requires_grad(model, feature_extracting):
#     if feature_extracting:
#         for param in model.parameters():
#             param.requires_grad = False

# class MultilabelClassifier(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()
#         self.resnet = resnet101(weights='DEFAULT')
        
#         self.model_wo_fc = nn.Sequential(*(list(self.resnet.children())[:-1]))
#         self.num_ftrs = self.resnet.fc.in_features
        
#         self.fc = nn.Linear(self.num_ftrs, num_classes)
        
#     def forward(self, x):
#         x = self.model_wo_fc(x)
#         x = torch.flatten(x, 1)
#         x = torch.sigmoid(self.fc(x))
#         return x

In [8]:
criterion = nn.CrossEntropyLoss() # as output is sigmoidless
import torchvision
# # get dataset corresponding to each split
# train_data = dataset_splits["train"]
# intask_val_data = dataset_splits["intask_valid"]
# posttask_val_data = dataset_splits["posttask_valid"]
# test_data = dataset_splits["test"]

# batch_size = 32
# num_classes = len(class_names_to_idx)

# resnet = MultilabelClassifier(num_classes)

model = resnet101(weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V2)

model.fc = nn.Linear(model.fc.in_features, 100)

model = torch.nn.DataParallel(model)
model = model.to(device)
optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

train(model=model, train_loader=train_loader, test_loader=test_loader, criterion=criterion, optimizer=optimizer_ft, epoch=20)


test(model, test_loader)
# # parallel
# # resnet = torch.nn.DataParallel(resnet)
# resnet = resnet.to(device)

# seen_classes = 0
# # initialize data to train on first task
# for task in range(len(tasks[:4])):
#     print("+++++TRAINING ON TASK {}++++++".format(task))
#     train_data.choose_task(task)
#     intask_val_data.choose_task(task)
#     posttask_val_data.choose_task(task)
    
#     trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
#     InTask_valloader = torch.utils.data.DataLoader(intask_val_data, batch_size=batch_size, shuffle=True, num_workers=2)
#     PostTask_valloader = torch.utils.data.DataLoader(posttask_val_data, batch_size=batch_size, shuffle=True, num_workers=2)
    
#     seen_classes += n_classes_per_task[task]
#     # print(tasks[task])

    
# #     test_model(resnet, PostTask_valloader,seen_classes, mode=1)

# # resnet = train_model(resnet, dataloader_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

Accuracy of the model on the test images: 1.02 %
Accuracy of the model on the test images: 34.7 %
Accuracy of the model on the test images: 48.0 %


48.0

In [15]:
loss_list, acc_list = train(model, train_loader, test_loader, criterion, optimizer_ft, epochs=200)

# plot loss and accuracy
plt.subplot(2, 1, 1)
plt.title("Accuracy")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.plot(loss_list)

plt.subplot(2, 1, 2)
plt.title("Accuracy")
plt.xlabel("Iterations")
plt.ylabel("Accuracy")
plt.plot(acc_list)

plt.tight_layout()
plt.show()


Accuracy of the model on the test images: 30.5 %
Accuracy of the model on the test images: 58.15 %
Accuracy of the model on the test images: 60.38 %
Accuracy of the model on the test images: 60.59 %
Accuracy of the model on the test images: 61.29 %
Accuracy of the model on the test images: 61.54 %
Accuracy of the model on the test images: 62.0 %
Accuracy of the model on the test images: 61.79 %
Accuracy of the model on the test images: 62.58 %
Accuracy of the model on the test images: 62.64 %
Accuracy of the model on the test images: 62.47 %
Accuracy of the model on the test images: 63.15 %
Accuracy of the model on the test images: 62.99 %
Accuracy of the model on the test images: 63.12 %
Accuracy of the model on the test images: 62.92 %
Accuracy of the model on the test images: 63.3 %
Accuracy of the model on the test images: 63.24 %
Accuracy of the model on the test images: 63.12 %
Accuracy of the model on the test images: 63.17 %
Accuracy of the model on the test images: 63.17 %
Acc

KeyboardInterrupt: 

In [44]:
# save model
torch.save(model.state_dict(), './feature_extractor/resnet101_70epochs.pth')

In [55]:
model = resnet101()
model.fc = nn.Linear(model.fc.in_features, 100)
# load saved model

model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('./feature_extractor/resnet101_70epochs.pth'))
optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
model = model.to(device)
# add scheduler
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

# train(model, train_loader, test_loader, criterion, optimizer_ft, scheduler=None, epochs=200)
model.eval()
test(model, test_loader)

Accuracy of the model on the test images: 62.96 %


62.96000000000001

In [35]:
model = resnet101()
model.fc = nn.Linear(model.fc.in_features, 100)
# load saved model

model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('./feature_extractor/resnet101_70epochs.pth'))
model = model.to(device)
model.eval()
test(model, test_loader)

Accuracy of the model on the test images: 62.96 %


62.96000000000001

In [39]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

__all__ = ['ResNetCIFAR']


def _weights_init(m):
    if isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight, nonlinearity="relu")
    elif isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, nonlinearity="sigmoid")


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='B', relu_output=True):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu_output = relu_output

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], [0, 0, 0, 0, planes // 4, planes // 4], "constant",
                                                  0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        if self.relu_output:
            out = F.relu(out)
        return out


class ResNetCIFARModule(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, relu_last_hidden=False):
        super(ResNetCIFARModule, self).__init__()
        self.in_planes = 16
        self.latent_dim = 64
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], 1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], 2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], 2, relu_last_hidden)
        self.output_layer = nn.Linear(self.latent_dim, self.num_classes)

    def _make_layer(self, block, planes, num_blocks, stride, relu_last_hidden=True):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(len(strides)):
            if i == (len(strides) - 1):
                layers.append(block(self.in_planes, planes, strides[i], relu_output=relu_last_hidden))
            else:
                layers.append(block(self.in_planes, planes, strides[i]))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.avg_pool2d(x, x.size()[3])
        x = x.view(x.size(0), -1)
        out = self.output_layer(x)
        return out


class ResNetCIFAR(nn.Module):
    def __init__(self, num_classes=10, num_layers=20, relu_last_hidden=False):
        super(ResNetCIFAR, self).__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.relu_last_hidden = relu_last_hidden

        if num_layers not in [20, 32, 44, 56, 110]:
            raise ValueError("For ResNetCifar, choose a number of layers out of 20, 32, 44, 56, and 110")
        elif num_layers == 20:
            self.model = ResNetCIFARModule(BasicBlock, [3, 3, 3], num_classes, relu_last_hidden)
        elif num_layers == 32:
            self.model = ResNetCIFARModule(BasicBlock, [5, 5, 5], num_classes, relu_last_hidden)
        elif num_layers == 44:
            self.model = ResNetCIFARModule(BasicBlock, [7, 7, 7], num_classes, relu_last_hidden)
        elif num_layers == 56:
            self.model = ResNetCIFARModule(BasicBlock, [9, 9, 9], num_classes, relu_last_hidden)
        elif num_layers == 110:
            self.model = ResNetCIFARModule(BasicBlock, [18, 18, 18], num_classes, relu_last_hidden)

        self.apply(_weights_init)

    def forward(self, input_):
        return self.model(input_)


In [None]:
model  = ResNetCIFAR(num_classes=100, num_layers=110, relu_last_hidden=False)
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('./feature_extractor/resnet110_70epochs.pth'))
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
# add scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=50, gamma=0.1)

train(model, train_loader, test_loader, criterion, optimizer_ft, scheduler=scheduler, epochs=200)

In [48]:
torch.save(model.state_dict(), './feature_extractor/resnet110_200epochs.pth')

In [51]:
model  = ResNetCIFAR(num_classes=100, num_layers=110, relu_last_hidden=False)
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('./feature_extractor/resnet110_200epochs.pth'))
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# add scheduler
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=50, gamma=0.1)

train(model, train_loader, test_loader, criterion, optimizer_ft, scheduler=None, epochs=200)

Accuracy of the model on the test images: 22.31 %
Accuracy of the model on the test images: 28.48 %
Accuracy of the model on the test images: 29.11 %
Accuracy of the model on the test images: 28.66 %
Accuracy of the model on the test images: 29.09 %
Accuracy of the model on the test images: 26.86 %
Accuracy of the model on the test images: 29.39 %


KeyboardInterrupt: 

In [56]:


torch.save(model.module.state_dict(), "./feature_extractor/resnet101_70epochs_any.pth")

# save module state dict

In [58]:
model = resnet101()
model.fc = nn.Linear(2048, 100)
model.load_state_dict(torch.load('./feature_extractor/resnet101_70epochs_any.pth'))
model = model.to(device)
model.eval()
test(model, test_loader)

Accuracy of the model on the test images: 62.96 %


62.96000000000001