In [434]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch 
from torch import nn
import os
import torchvision
import torch.nn.functional as F
from torch import tensor
from torch.utils.data import DataLoader
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
from tqdm import tqdm
from matplotlib import pyplot as plt
import json

In [527]:
class CustomCNN(nn.Module):
    
    def __init__(self, num_classes):
        
        super(CustomCNN, self).__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.batch_norm1 = nn.BatchNorm2d(3)

        self.conv3 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.batch_norm2 = nn.BatchNorm2d(3)

        self.conv5 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.batch_norm3 = nn.BatchNorm2d(3)

        self.conv8 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=1)  # default stride is 2
        # self.batch_norm4 = nn.BatchNorm2d(3)

        self.conv11 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=1)  # default stride is 2
        self.batch_norm5 = nn.BatchNorm2d(3)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(3, num_classes)
        
    def forward(self, x):
        
        x1 = F.leaky_relu(self.conv1(x))
        x1 = self.pool1(x1)
        # x1 = self.batch_norm1(x1)
        
        x2 = F.leaky_relu(self.conv3(x1))
        x2 = self.pool2(x2)
        # x2 = self.batch_norm2(x2)

        x3 = F.leaky_relu(self.conv5(x2))
        x3 = self.pool3(x3)
        # x3 = self.batch_norm3(x3)

        x4 = F.leaky_relu(self.conv8(x3))
        x4 = self.pool4(x4)
        # x4 = self.batch_norm4(x4)

        x5 = F.leaky_relu(self.conv11(x4))
        x5 = self.pool5(x5)
        x5 = self.batch_norm5(x5)
        
        y = self.flatten(x5)
        y = F.leaky_relu(self.fc1(y))
        return x1, x2, x3, x4, x5, y

In [528]:
class CustomCNN_sieve(nn.Module):

    def __init__(self, shape1, shape2, shape3, shape4, shape5, num_classes):

        super().__init__()
        self.num_classes = num_classes
        self.core_network = CustomCNN(self.num_classes)
        
        self.shape1 = shape1
        self.shape2 = shape2
        self.shape3 = shape3
        self.shape4 = shape4
        self.shape5 = shape5
        
        self.sieve1 = nn.Linear(self.shape1[0] * self.shape1[1] * self.shape1[2], self.num_classes)
        self.sieve2 = nn.Linear(self.shape2[0] * self.shape2[1] * self.shape2[2], self.num_classes)
        self.sieve3 = nn.Linear(self.shape3[0] * self.shape3[1] * self.shape3[2], self.num_classes)
        self.sieve4 = nn.Linear(self.shape4[0] * self.shape4[1] * self.shape4[2], self.num_classes)
        self.sieve5 = nn.Linear(self.shape5[0] * self.shape5[1] * self.shape5[2], self.num_classes)
        self.flatten = nn.Flatten(start_dim = 1, end_dim = 3)

    def forward(self, x):

        s1, s2, s3, s4, s5, main_output = self.core_network(x)
        return self.sieve1(self.flatten(s1)), self.sieve2(self.flatten(s2)), self.sieve3(self.flatten(s3)), self.sieve4(self.flatten(s4)), self.sieve5(self.flatten(s5)), main_output
        

In [529]:
def training(model, dataloader, optimizer_core, optimizer_sieve, num_classes, lamb, epochs = 10):

    for epoch in tqdm(range(epochs)):

        epoch_loss_main = 0 
        epoch_loss_sieve = 0 
        iters = 0 
        
        for iter, (image, label) in enumerate(dataloader): 

            optimizer_core.zero_grad()
            optimizer_sieve.zero_grad()
            
            image, label = image.to(device), label.to(device)
            model.train()
            
            _, _, _, _, _, main_output = model(image)
            main_loss = F.binary_cross_entropy_with_logits(main_output, label)
            epoch_loss_main += main_loss.item()
            iters += 1
            
            main_loss.backward()

            # if(iter%30 == 0):
            #     for name, param in model_sieve.named_parameters():
            #         if param.grad is None:
            #             print(f"Layer {name} has no gradients.")
            #         else:
            #             print(f"Layer {name} gradient mean: {param.grad.mean()}")

            #     print('\n')
                        
            optimizer_core.step()
            model.core_network.eval()

            s1, s2, s3, s4, s5, _ = model(image)
            loss1 = F.cross_entropy(s1, label)
            loss2 = F.cross_entropy(s2, label)
            loss3 = F.cross_entropy(s3, label)
            loss4 = F.cross_entropy(s4, label)
            loss5 = F.cross_entropy(s5, label)

            loss = loss1 + loss2 + loss3 + loss4 + loss5
            epoch_loss_sieve += loss.item()
            # optimizer_model.zero_grad()
            # optimizer_weights.zero_grad()
            loss.backward()

            # if(iter%20 == 0):
            #     for name, param in model_sieve.named_parameters():
            #         if param.grad is None:
            #             print(f"Layer {name} has no gradients.")
            #         else:
            #             print(f"Layer {name} gradient mean: {param.grad.mean()}")

            #     print('\n')
                
            optimizer_sieve.step()
            # optimizer_weights.step()
            image,label = image.to('cpu'), label.to('cpu')
            
            
        print(f'The epoch is {epoch + 1}, the main loss is {epoch_loss_main/iters} and the sieve loss is {epoch_loss_sieve/iters}')

In [530]:
def forgetting(model, weights, dataloader, optimizer_core, num_classes, lamb, epochs = 10):

    for epoch in range(epochs):
        
        iters = 0 
        total_loss1 = 0
        total_loss2 = 0
        total_loss3 = 0
        total_loss4 = 0
        total_loss5 = 0
        
        for image, _ in dataloader: 

            image = image.to(device)
            label = torch.ones((image.shape[0] , num_classes), dtype = torch.float).to(device)
            label = label/num_classes 
            iters+=1 
            
            model.eval()
            s1, s2, s3, s4, s5, _ = model(image)
            optimizer_core.zero_grad() 
            
            model.sieve1.train()
            model.core_network.conv1.train()
            model.core_network.pool1.train()
            # model.core_network.batch_norm1.train()
            loss1 = weights[0]*F.binary_cross_entropy_with_logits(s1, label)
            total_loss1 += loss1.item()
            loss1.backward(retain_graph=True)
            model.eval()

            model.sieve2.train()
            model.core_network.conv3.train()
            model.core_network.pool2.train()
            # model.core_network.batch_norm2.train()
            loss2 = weights[1]*F.binary_cross_entropy_with_logits(s2, label)
            total_loss2 += loss2.item()
            loss2.backward(retain_graph=True)
            model.eval()

            model.sieve3.train()
            model.core_network.conv5.train()
            model.core_network.pool3.train()
            # model.core_network.batch_norm3.train()
            loss3 = weights[2]*F.binary_cross_entropy_with_logits(s3, label)
            total_loss3 += loss3.item()
            loss3.backward(retain_graph=True)
            model.eval()

            model.sieve4.train()
            model.core_network.conv8.train()
            model.core_network.pool4.train()
            # model.core_network.batch_norm4.train()
            loss4 = weights[3]*F.binary_cross_entropy_with_logits(s4, label)
            total_loss4 += loss4.item()
            loss4.backward(retain_graph=True)
            model.eval()

            model.sieve5.train()
            model.core_network.conv11.train()
            model.core_network.pool5.train()
            model.core_network.batch_norm5.train()
            loss5 = weights[4]*F.binary_cross_entropy_with_logits(s5, label)
            total_loss5 += loss5.item()
            loss5.backward(retain_graph=True)
            model.eval()

            optimizer_core.step()

            image = image.to('cpu')
            label = label.to('cpu')
            # loss = -(loss1 + loss2 + loss3 + loss4 + loss5) + lamb*torch.abs((weights[0] + weights[1] + weights[2] + weights[3] + weights[4]) - 1)
            # loss.backward()

            # optimizer_weights.zero_grad()
            # optimizer_weights.step()

        print(f"The epoch is {epoch}.\n loss1 is {total_loss1/iters} \n loss2 is {total_loss2/iters} \n loss3 is {total_loss3/iters} \n loss4 is {total_loss4/iters} \n loss5 is {total_loss5/iters}")

## Code to load cifarmnist dataset

In [614]:
labels_path = train_images_path + 'labels.json'
with open(labels_path, "r") as json_file:
    data = json.load(json_file)

In [407]:
data['Datas/Cifar10_Mnist_Lite/combined_image_999.png']

{'cifar10_label': 6, 'mnist_label': 5}

In [410]:
train_images_path = '/kaggle/input/cifar-10-mnist/Cifar10_Mnist_Lite/'
start_path = 'Datas/Cifar10_Mnist_Lite/'
cifar_dataset = []
for image_file in os.listdir(train_images_path):
    if image_file.endswith('.json'): continue
    img = tensor(torchvision.io.read_image(train_images_path + image_file), dtype = torch.float)
    label_cifar = data[start_path + image_file]['cifar10_label']
    label_mnist = data[start_path + image_file]['mnist_label']
    cifar_dataset.append((img,label_cifar,label_mnist))

  img = tensor(torchvision.io.read_image(train_images_path + image_file), dtype = torch.float)


## Code to load colormnist dataset

In [668]:
# green 0 and red 1 in training 
# red 0 and green 1 in testing

In [669]:
train_0_path = '/kaggle/input/colored-mnist-dataset/colorized-MNIST-master/training/0/'
train_1_path = '/kaggle/input/colored-mnist-dataset/colorized-MNIST-master/training/1/'

In [670]:
train_dataset = []
for image_file in os.listdir(train_0_path):
    img = tensor(torchvision.io.read_image(train_0_path + image_file), dtype = torch.float)
    image_corner = img[:, :2, :2]
    if(torch.mean(image_corner, dim = (1,2))[1] > 0):
        train_dataset.append((img, torch.tensor([1,0], dtype = torch.float)))
        
for image_file in os.listdir(train_1_path):
    img = tensor(torchvision.io.read_image(train_1_path + image_file), dtype = torch.float)
    image_corner = img[:, :2, :2]
    if(torch.mean(image_corner, dim = (1,2))[0] > 0):
        train_dataset.append((img, torch.tensor([0,1], dtype = torch.float)))

  img = tensor(torchvision.io.read_image(train_0_path + image_file), dtype = torch.float)
  img = tensor(torchvision.io.read_image(train_1_path + image_file), dtype = torch.float)


In [671]:
test_0_path = '/kaggle/input/colored-mnist-dataset/colorized-MNIST-master/testing/0/'
test_1_path = '/kaggle/input/colored-mnist-dataset/colorized-MNIST-master/testing/1/'

In [672]:
test_dataset = []
for image_file in os.listdir(test_0_path):
    img = tensor(torchvision.io.read_image(test_0_path + image_file), dtype = torch.float)
    image_corner = img[:, :2, :2]
    if(torch.mean(image_corner, dim = (1,2))[0] > 0):
        test_dataset.append((img, torch.tensor([1,0], dtype = torch.float)))

for image_file in os.listdir(test_1_path):
    img = tensor(torchvision.io.read_image(test_1_path + image_file), dtype = torch.float)
    image_corner = img[:, :2, :2]
    if(torch.mean(image_corner, dim = (1,2))[1] > 0):
        test_dataset.append((img, torch.tensor([0,1], dtype = torch.float)))

  img = tensor(torchvision.io.read_image(test_0_path + image_file), dtype = torch.float)
  img = tensor(torchvision.io.read_image(test_1_path + image_file), dtype = torch.float)


In [673]:
valid_dataset = train_dataset[int((len(train_dataset)) * 0.9):]
train_dataset = train_dataset[:int((len(train_dataset)) * 0.9)]

In [674]:
sample_image = train_dataset[0][0]

In [675]:
model = CustomCNN(2)

In [676]:
model.eval()
s1, s2, s3, s4, s5, main_out = model(sample_image.unsqueeze(0))

In [677]:
shape1 = s1.shape[1:]
shape2 = s2.shape[1:] 
shape3 = s3.shape[1:] 
shape4 = s4.shape[1:] 
shape5 = s5.shape[1:]

In [678]:
model_sieve = CustomCNN_sieve(shape1, shape2, shape3, shape4, shape5, 2).to(device)

In [679]:
weights = tensor([0.5, 0.3, 0.1, 0.07, 0.03], device = device, requires_grad = False)

In [662]:
dataloader = DataLoader(train_dataset, batch_size = 32, shuffle = True)
opt_core = torch.optim.Adam(model_sieve.parameters(), lr = 1e-3)
opt_sieve = torch.optim.Adam(
    list(model_sieve.sieve1.parameters()) +
    list(model_sieve.sieve2.parameters()) +
    list(model_sieve.sieve3.parameters()) +
    list(model_sieve.sieve4.parameters()) +
    list(model_sieve.sieve5.parameters()),
    lr=1e-3
)

In [663]:
total_iterations = 1

In [664]:
def get_accuracy(test_dataset, model):
    model.eval()
    correct = 0
    for i in range(len(test_dataset)):
        image = test_dataset[i][0].to(device).unsqueeze(0)
        label = test_dataset[i][1].to(device)
        
        _, _, _ ,_,_, main_pred = model(image)
        correct += torch.argmax(main_pred) == torch.argmax(test_dataset[i][1])
        
        image = image.to('cpu')
        label = label.to('cpu')
    return correct.item()/len(test_dataset) * 100 

In [665]:
model.train()
for i in range(total_iterations):
    training(model_sieve, dataloader, opt_core, opt_sieve, 2, 1, 10)
    # forgetting(model_sieve, weights, dataloader, opt_core, 2, 1, 5)
    print(f"The iteration is {i} and the validation set accuracy is {get_accuracy(valid_dataset, model_sieve)}")

 20%|██        | 2/10 [00:00<00:01,  7.38it/s]

The epoch is 1, the main loss is 0.7827012538909912 and the sieve loss is 12.432634970721077
The epoch is 2, the main loss is 0.6345311263028313 and the sieve loss is 3.999729465035831


 40%|████      | 4/10 [00:00<00:00,  7.46it/s]

The epoch is 3, the main loss is 0.5803164839744568 and the sieve loss is 2.7727481056662167
The epoch is 4, the main loss is 0.48568326760740843 and the sieve loss is 2.6821666885824764


 60%|██████    | 6/10 [00:00<00:00,  7.47it/s]

The epoch is 5, the main loss is 0.4695095875683953 and the sieve loss is 2.3131158772636864
The epoch is 6, the main loss is 0.46239835900418896 and the sieve loss is 1.9722502582213457


 80%|████████  | 8/10 [00:01<00:00,  7.49it/s]

The epoch is 7, the main loss is 0.45882128091419444 and the sieve loss is 1.5602403528550093
The epoch is 8, the main loss is 0.4535786807537079 and the sieve loss is 1.2505750726251041


100%|██████████| 10/10 [00:01<00:00,  7.47it/s]

The epoch is 9, the main loss is 0.4475888662478503 and the sieve loss is 1.0116958863594954
The epoch is 10, the main loss is 0.44402795679428997 and the sieve loss is 0.8520476677838493





The iteration is 0 and the validation set accuracy is 100.0


In [666]:
get_accuracy(train_dataset, model_sieve)

100.0

In [667]:
get_accuracy(test_dataset, model_sieve)

9.256661991584853

In [None]:
image_tensor = next(train_iter)[0]
plt.imshow(image_tensor.permute(1,2,0))
image_tensor = image_tensor[:, :5, :5]
print(image_tensor)