In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader,random_split,Dataset
import torch.optim as optim
from tqdm import tqdm
from training_utils import *

In [None]:
# One seed shared by Python's `random`, NumPy and PyTorch so that data
# shuffling, dropout masks and weight initialisation repeat run-to-run.
seed = 43

torch.manual_seed(seed)   # PyTorch default generator
np.random.seed(seed)      # legacy NumPy global RNG
random.seed(seed)         # Python stdlib RNG

# Restrict cuDNN to deterministic convolution algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
# CIFAR-10 train/test splits read from ./data/. download=False means the
# dataset files must already be present there. Each image is passed through
# ToTensor.
trainset = datasets.CIFAR10(root='./data/', train=True, download=False, transform=transforms.ToTensor())
testset = datasets.CIFAR10(root='./data/', train=False, download=False, transform=transforms.ToTensor())

# CIFAR-10 class names, indexed by label id.
labels_list = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
# Label ids of the non-animal classes (airplane, automobile, ship, truck);
# every other class counts as "animal" for the auxiliary binary task.
non_animal = [0,1,8,9]
# Fall back to the CPU when no CUDA device is present, so the notebook still
# runs end-to-end (Restart & Run All) on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class NewDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and add a binary animal/non-animal label.

    Each item is returned as ``(image, label1, label2)``:
      * ``label1`` -- the original class id taken from ``data``;
      * ``label2`` -- 0 if ``label1`` is one of the non-animal class ids, else 1.

    Args:
        data: indexable dataset whose items expose the image at ``[0]`` and the
            class id at ``[1]``.
        transform: optional callable applied to the image before it is returned.
            For backward compatibility, a *non-callable* value passed here
            (e.g. ``NewDataset(trainset, non_animal)``) is interpreted as the
            collection of non-animal class ids instead.
        non_animal_classes: class ids mapped to ``label2 == 0``. When neither
            this nor a non-callable ``transform`` is given, the module-level
            ``non_animal`` list is used.
    """

    def __init__(self, data, transform=None, non_animal_classes=None):
        # Legacy call form: the class-id list arrived in the `transform` slot.
        if transform is not None and not callable(transform):
            if non_animal_classes is None:
                non_animal_classes = transform
            transform = None
        if non_animal_classes is None:
            non_animal_classes = non_animal
        self.data = data
        self.transform = transform
        # A tuple keeps the ==-based membership test of the original list.
        self.non_animal_classes = tuple(non_animal_classes)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Index the wrapped dataset once. For datasets such as torchvision's
        # CIFAR10 every access re-runs the per-item transform, so the former
        # three separate lookups tripled that work.
        sample = self.data[idx]
        image, label1 = sample[0], sample[1]          # original label
        if self.transform is not None:
            image = self.transform(image)
        label2 = 0 if label1 in self.non_animal_classes else 1   # animal or non-animal
        return image, label1, label2

In [None]:
# Wrap both CIFAR-10 splits so every item also carries the animal/non-animal
# label. NewDataset takes the non-animal ids from the module-level `non_animal`
# list; passing that list positionally would land in its `transform` slot.
new_trainset = NewDataset(trainset)
new_testset = NewDataset(testset)

# 90/10 train/validation split with its own fixed generator, so the split is
# reproducible independently of the global seed. The validation size is the
# remainder, so the lengths always sum to len(new_trainset) even when it is
# not divisible by 10.
n_train = int(len(new_trainset) * 0.9)
train_set, valid_set = random_split(new_trainset, [n_train, len(new_trainset) - n_train],
                                    generator=torch.Generator().manual_seed(0))

# Only training needs a shuffled order; aggregate validation/test metrics do
# not depend on batch order, so those loaders iterate sequentially.
train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=100, shuffle=False)
test_loader = DataLoader(new_testset, batch_size=100, shuffle=False)

In [None]:
class MTL_Net(nn.Module):
    """Small CNN with a shared trunk and two classification heads.

    Args:
        input_channel: number of channels in the input images.
        num_class: pair of head sizes -- ``num_class[0]`` sizes the first head
            (``fc3``) and ``num_class[1]`` the second (``fc4``).

    ``fc1`` expects 64 flattened features, i.e. 16 channels x 2 x 2 after the
    two conv(3x3) + max-pool(3) stages, which is what a 32x32 input produces.
    """

    def __init__(self, input_channel, num_class):
        super().__init__()
        self.classes = num_class

        # Shared trunk. Modules are created in the same order as before so that
        # a seeded run initialises identical weights and state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=8, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(64, 256)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.3)

        # One linear head per task on top of the 128-d shared representation.
        first_head, second_head = self.classes[0], self.classes[1]
        self.fc3 = nn.Linear(128, first_head)
        self.fc4 = nn.Linear(128, second_head)

    def _shared_features(self, x):
        """Run the shared trunk and return the 128-d representation."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=3)
        flat_dim = x.shape[1] * x.shape[2] * x.shape[3]
        hidden = self.dropout1(F.relu(self.fc1(x.reshape(-1, flat_dim))))
        return self.dropout2(F.relu(self.fc2(hidden)))

    def forward(self, x):
        shared = self._shared_features(x)
        return self.fc3(shared), self.fc4(shared)

In [None]:
class Net(nn.Module):
    """Small CNN with a single classification head.

    Args:
        input_channel: number of channels in the input images.
        num_class: number of logits produced by the head ``fc3``.

    ``fc1`` expects 64 flattened features, i.e. 16 channels x 2 x 2 after the
    two conv(3x3) + max-pool(3) stages, which is what a 32x32 input produces.
    """

    def __init__(self, input_channel, num_class):
        super().__init__()
        self.classes = num_class

        # Feature extractor. Creation order is unchanged so that a seeded run
        # initialises identical weights and state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=8, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(64, 256)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.3)

        # Classification head.
        self.fc3 = nn.Linear(128, self.classes)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=3)
        flat_dim = x.shape[1] * x.shape[2] * x.shape[3]
        hidden = self.dropout1(F.relu(self.fc1(x.reshape(-1, flat_dim))))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

In [None]:
# Multi-task setup: head 0 has 10 outputs (CIFAR-10 classes), head 1 has 2
# (animal vs. non-animal).
num_classes = [10, 2]
model = MTL_Net(3, num_classes)   # 3 input channels
model = model.to(device)

criterion = nn.CrossEntropyLoss()   # applied to both heads
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
def train_model(model,trainloader,optim,criterion,epoch,device):
    """Train a two-headed (multi-task) model for one epoch.

    Every batch from ``trainloader`` must be ``(inputs, tg1, tg2)`` and
    ``model(inputs)`` must return ``(op1, op2)``. The optimised objective is
    ``criterion(op1, tg1) + criterion(op2, tg2)``.

    Args:
        model: module returning one logit tensor per task.
        trainloader: iterable of ``(inputs, tg1, tg2)`` batches.
        optim: optimizer over ``model``'s parameters (shadows the
            ``torch.optim`` alias inside this function only).
        criterion: loss applied to each head.
        epoch: zero-based epoch index; printed as ``epoch + 1``.
        device: device the batch tensors are moved to.

    Returns:
        float: mean per-batch combined loss over the epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    model.train()
    train_loss, total, total_correct1, total_correct2 = 0.0, 0, 0, 0
    num_batches = 0

    for inputs, tg1, tg2 in tqdm(trainloader):
        inputs, tg1, tg2 = inputs.to(device), tg1.to(device), tg2.to(device)
        optim.zero_grad()

        op1, op2 = model(inputs)
        total_loss = criterion(op1, tg1) + criterion(op2, tg2)
        total_loss.backward()
        optim.step()

        # Accumulate a Python float. Summing the loss tensors themselves chains
        # every batch's result into one autograd graph for the whole epoch and
        # makes the function return a graph-attached tensor.
        train_loss += total_loss.item()
        total_correct1 += (op1.argmax(dim=1) == tg1).sum().item()
        total_correct2 += (op2.argmax(dim=1) == tg2).sum().item()
        total += tg1.size(0)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches")

    mean_loss = train_loss / num_batches
    print("Epoch: [{}]  loss: [{:.2f}] Original_task_acc [{:.2f}] animal_vs_non_animal_acc [{:.2f}]".format(
        epoch+1, mean_loss,
        (total_correct1*100/total),
        (total_correct2*100/total)))
    return mean_loss

def train_single_model(model,trainloader,optim,criterion,epoch,device,target_idx=0):
    """Train a single-head model for one epoch on one of the two label sets.

    Batches are ``(inputs, tg1, tg2)``. In this notebook ``tg1`` is the
    original class label and ``tg2`` the animal/non-animal label (see
    ``NewDataset``). Loss and accuracy both use the same selected target, so
    switching tasks is a single argument rather than an edit in two places.

    Args:
        model: module returning one logit tensor.
        trainloader: iterable of ``(inputs, tg1, tg2)`` batches.
        optim: optimizer over ``model``'s parameters.
        criterion: loss function.
        epoch: zero-based epoch index; printed as ``epoch + 1``.
        device: device the tensors are moved to.
        target_idx: 0 trains on ``tg1`` (default, the original behaviour);
            1 trains on ``tg2``.

    Returns:
        float: mean per-batch loss over the epoch.

    Raises:
        ValueError: if ``target_idx`` is not 0 or 1, or ``trainloader`` yields
            no batches.
    """
    if target_idx not in (0, 1):
        raise ValueError("target_idx must be 0 (tg1) or 1 (tg2), got {!r}".format(target_idx))
    model.train()
    train_loss, total, total_correct = 0.0, 0, 0
    num_batches = 0

    for inputs, tg1, tg2 in tqdm(trainloader):
        # Only the selected target needs to be moved to the device.
        target = (tg1, tg2)[target_idx]
        inputs, target = inputs.to(device), target.to(device)
        optim.zero_grad()

        op = model(inputs)
        loss = criterion(op, target)
        loss.backward()
        optim.step()

        train_loss += loss.item()
        total_correct += (op.argmax(dim=1) == target).sum().item()
        total += target.size(0)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches")

    mean_loss = train_loss / num_batches
    print("Epoch: [{}]  loss: [{:.2f}] Acc [{:.2f}] ".format(epoch+1, mean_loss,
                                                           (total_correct*100/total),
                                                           ))
    return mean_loss

In [None]:
def test_model(model,testloader,optim,criterion,epoch,device):
    """Evaluate a two-headed (multi-task) model without gradient tracking.

    Batches are ``(inputs, tg1, tg2)``; ``model(inputs)`` must return
    ``(op1, op2)``, and each output is scored with ``criterion`` against its
    own target. ``optim`` is accepted but not used.

    Returns:
        tuple: ``(mean per-batch combined loss, accuracy % on tg1,
        accuracy % on tg2)``.
    """
    model.eval()
    running_loss = 0
    n_seen, hits1, hits2, n_batches = 0, 0, 0, 0

    with torch.no_grad():
        for inputs, tg1, tg2 in tqdm(testloader):
            inputs, tg1, tg2 = (t.to(device) for t in (inputs, tg1, tg2))

            op1, op2 = model(inputs)
            batch_loss = criterion(op1, tg1).item() + criterion(op2, tg2).item()
            running_loss += batch_loss

            hits1 += (op1.data.max(dim=1).indices == tg1).sum().item()
            hits2 += (op2.data.max(dim=1).indices == tg2).sum().item()
            n_seen += tg1.size(0)
            n_batches += 1

    acc1 = 100. * hits1 / n_seen
    acc2 = 100. * hits2 / n_seen
    mean_loss = running_loss / n_batches
    print("Test Epoch: [{}]  loss: [{:.2f}] Original_task_Acc [{:.2f}] animal_vs_non_animal_acc [{:.2f}]".format(
        epoch+1, mean_loss, acc1, acc2))
    return mean_loss, acc1, acc2

def test_single_model(model,testloader,optim,criterion,epoch,device,target_idx=0):
    """Evaluate a single-head model without gradient tracking.

    Batches are ``(inputs, tg1, tg2)``; the model is scored against the target
    chosen by ``target_idx``, so evaluation matches ``train_single_model``
    without editing the code.

    Args:
        model: module returning one logit tensor.
        testloader: iterable of ``(inputs, tg1, tg2)`` batches.
        optim: accepted but not used (the signature mirrors
            ``train_single_model``).
        criterion: loss function.
        epoch: zero-based epoch index; printed as ``epoch + 1``.
        device: device the tensors are moved to.
        target_idx: 0 scores against ``tg1`` (default, the original
            behaviour); 1 scores against ``tg2``.

    Returns:
        tuple: ``(mean per-batch loss, accuracy in percent)``.

    Raises:
        ValueError: if ``target_idx`` is not 0 or 1, or ``testloader`` yields
            no batches.
    """
    if target_idx not in (0, 1):
        raise ValueError("target_idx must be 0 (tg1) or 1 (tg2), got {!r}".format(target_idx))
    model.eval()
    test_loss, total, total_correct = 0.0, 0, 0
    num_batches = 0

    with torch.no_grad():
        for inputs, tg1, tg2 in tqdm(testloader):
            target = (tg1, tg2)[target_idx]
            inputs, target = inputs.to(device), target.to(device)

            op = model(inputs)
            test_loss += criterion(op, target).item()
            total_correct += (op.argmax(dim=1) == target).sum().item()
            total += target.size(0)
            num_batches += 1

    if num_batches == 0:
        raise ValueError("testloader yielded no batches")

    acc1 = 100. * total_correct / total
    mean_loss = test_loss / num_batches
    print("Test Epoch: [{}]  loss: [{:.2f}] Acc [{:.2f}] ".format(epoch+1, mean_loss,
                                                                acc1))
    return mean_loss, acc1

### Training the multi-task (MTL) model

In [None]:
num_epochs = 50

# Each epoch: one optimisation pass over the training split, then an
# evaluation pass over the validation split. test_model labels its printout
# "Test Epoch" even when it is given the validation loader. Both helpers print
# their own summaries, so their return values are discarded.
for epoch in range(num_epochs):
    _ = train_model(model, train_loader, optimizer, criterion, epoch, device)
    _, _, _ = test_model(model, valid_loader, optimizer, criterion, epoch, device)
    

In [None]:
# Final evaluation of the multi-task model on the held-out test split. The
# epoch index is derived from `num_epochs`, not from the training loop's
# leaked `epoch` variable; the printed epoch number is unchanged. The results
# are kept for later inspection.
mtl_test_loss, mtl_test_acc_original, mtl_test_acc_animal = test_model(
    model, test_loader, optimizer, criterion, num_epochs - 1, device)

### Training a single-task baseline

In [None]:
# Single-task baseline: same trunk as the multi-task model, one 10-way head.
num_classes = 10
model = Net(3, num_classes)   # 3 input channels
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
num_epochs = 50

# Each epoch: train the single-head model, then evaluate it on the validation
# split. The helpers print their own summaries; return values are discarded.
for epoch in range(num_epochs):
    _ = train_single_model(model, train_loader, optimizer, criterion, epoch, device)
    _, _ = test_single_model(model, valid_loader, optimizer, criterion, epoch, device)


In [None]:
# Final evaluation of the single-task model on the held-out test split. The
# epoch index comes from `num_epochs` rather than the loop's leaked `epoch`
# (same printed value), and the results are kept for comparison with the
# multi-task model.
single_test_loss, single_test_acc = test_single_model(
    model, test_loader, optimizer, criterion, num_epochs - 1, device)