In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader,random_split,Dataset
import torch.optim as optim
from tqdm import tqdm

In [2]:
# Runtime configuration. DATA_DIR is the single place to change when running on
# another machine; download=False means CIFAR-10 must already exist there.
DATA_DIR = '/home/aminul/data1/'
SEED = 0
# Seed the torch RNG early so weight init, DataLoader shuffling and dropout are reproducible.
torch.manual_seed(SEED)

trainset = datasets.CIFAR10(root=DATA_DIR, train=True, download=False, transform=transforms.ToTensor())
testset = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transforms.ToTensor())

# CIFAR-10 class names, indexed by integer label.
labels_list = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
# Label ids of airplane, automobile, ship and truck; every other class is an animal.
non_animal = [0,1,8,9]

# Prefer the second GPU (the original setting), but fall back so the notebook
# still runs on machines with a single GPU or none.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [3]:
class NewDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and return three targets per sample.

    ``__getitem__`` returns ``(image, label1, label2, label3)`` where

    * ``label1`` is the original class id,
    * ``label2`` is 0 for non-animal classes and 1 for animal classes,
    * ``label3`` is 0 for class ids greater than 5 and 1 otherwise -- a fixed,
      arbitrary split of the classes (deterministic, not random).

    Args:
        data: indexable dataset whose items are ``(image, label)`` pairs.
        transform: optional callable applied to the image. Non-callable values
            are ignored, so passing something else positionally (e.g. a list)
            keeps the old behaviour of not transforming at all.
        non_animal_classes: class ids treated as non-animal. When ``None`` the
            module-level ``non_animal`` list is read at access time.
    """

    def __init__(self, data, transform=None, non_animal_classes=None):
        self.data = data
        self.transform = transform if callable(transform) else None
        self.non_animal_classes = non_animal_classes

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Index the wrapped dataset exactly once: each indexing of a torchvision
        # dataset re-applies its own transform (ToTensor here), so the previous
        # four lookups per sample did the same work four times.
        image, label = self.data[idx]
        if self.transform is not None:
            image = self.transform(image)

        non_animal_ids = self.non_animal_classes if self.non_animal_classes is not None else non_animal
        label2 = 0 if label in non_animal_ids else 1   # animal (1) vs non-animal (0)
        label3 = 0 if label > 5 else 1                 # fixed split: ids 6-9 -> 0, ids 0-5 -> 1
        return image, label, label2, label3

In [4]:
# NewDataset's second positional parameter is `transform`, not the class list;
# the animal/non-animal split is taken from the global `non_animal`, so nothing
# extra is passed here.
new_trainset = NewDataset(trainset)
new_testset = NewDataset(testset)

BATCH_SIZE = 100
train_loader = DataLoader(new_trainset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics do not depend on sample order, so keep the test order fixed.
test_loader = DataLoader(new_testset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class Net(nn.Module):
    """Small two-conv CNN with a shared trunk and three linear classification heads.

    Args:
        input_channel: number of channels of the input images.
        num_class: sequence of three ints; the output sizes of the heads
            ``fc3``, ``fc4`` and ``fc5`` respectively.

    ``fc1`` expects 64 flattened features, which matches 32x32 inputs:
    32 -conv3-> 30 -pool3-> 10 -conv3-> 8 -pool3-> 2, and 16 * 2 * 2 = 64.
    """

    def __init__(self, input_channel, num_class):
        super().__init__()
        self.classes = num_class
        head_sizes = self.classes

        # Layers are created in the same order as before so parameter order
        # (and hence seeded initialisation) is unchanged.
        # Shared convolutional trunk.
        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=8, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
        # Shared fully connected trunk.
        self.fc1 = nn.Linear(64, 256)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.3)

        # One linear head per task.
        self.fc3 = nn.Linear(128, head_sizes[0])
        self.fc4 = nn.Linear(128, head_sizes[1])
        self.fc5 = nn.Linear(128, head_sizes[2])

    def forward(self, x):
        """Return the logits of the three heads as a tuple ``(fc3, fc4, fc5)``."""
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=3)
        features = F.max_pool2d(F.relu(self.conv2(features)), kernel_size=3)

        _, channels, height, width = features.shape
        flat = features.reshape(-1, channels * height * width)

        hidden = self.dropout1(F.relu(self.fc1(flat)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))

        heads = (self.fc3, self.fc4, self.fc5)
        return tuple(head(hidden) for head in heads)

In [6]:
# Output size per head: 10-way original label, animal vs non-animal, fixed binary split.
num_classes = [10, 2, 2]
net = Net(input_channel=3, num_class=num_classes)
net = net.to(device)

# A single CrossEntropyLoss instance is shared by all three heads.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=5e-4,
)

In [7]:
def train(net,trainloader,optim,criterion,epoch,device):
    """Run one training epoch over `trainloader` and print per-head accuracies.

    Args:
        net: model returning three logit tensors (one per head).
        trainloader: iterable of ``(inputs, tg1, tg2, tg3)`` batches.
        optim: optimizer stepping `net`'s parameters.
        criterion: loss applied to each head; the three losses are summed.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        device: device the batch tensors are moved to.

    Returns:
        Mean over batches of the summed three-head loss.

    Raises:
        ValueError: if `trainloader` yields no batches.
    """
    net.train()
    train_loss, total, num_batches = 0.0, 0, 0
    correct = [0, 0, 0]

    for inputs, tg1, tg2, tg3 in tqdm(trainloader):
        inputs, tg1, tg2, tg3 = inputs.to(device), tg1.to(device), tg2.to(device), tg3.to(device)
        optim.zero_grad()

        op1, op2, op3 = net(inputs)
        # One backward on the summed loss gives the same gradients as three
        # backward passes with retain_graph=True, but traverses the shared
        # trunk once and frees the graph right away.
        loss = criterion(op1, tg1) + criterion(op2, tg2) + criterion(op3, tg3)
        loss.backward()
        optim.step()

        train_loss += loss.item()
        num_batches += 1
        total += tg1.size(0)
        for k, (out, tgt) in enumerate(((op1, tg1), (op2, tg2), (op3, tg3))):
            correct[k] += (out.argmax(1) == tgt).sum().item()

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches")

    avg_loss = train_loss / num_batches
    acc1, acc2, acc3 = (c * 100 / total for c in correct)
    print("Epoch: [{}]  loss: [{:.2f}] Orig_Acc [{:.2f}] animal_Acc [{:.2f}] random_Acc [{:.2f}] ".format(
        epoch + 1, avg_loss, acc1, acc2, acc3))
    return avg_loss

In [8]:
def tester(net,testloader,optim,criterion,epoch,device):
    net.eval()
    test_loss,total,total_correct1,total_correct2,total_correct3 = 0,0,0,0,0
    
    for i,(inputs,tg1,tg2,tg3) in enumerate(tqdm(testloader)):
        
        inputs,tg1,tg2,tg3 = inputs.to(device), tg1.to(device), tg2.to(device), tg3.to(device)
        
        op1,op2,op3 = net(inputs)
        loss1 = criterion(op1,tg1)
        loss2 = criterion(op2,tg2)
        loss3 = criterion(op3,tg3)
        
        test_loss += loss1.item() + loss2.item() + loss3.item()
        _,pd1 = torch.max(op1.data,1)
        _,pd2 = torch.max(op2.data,1)
        _,pd3 = torch.max(op3.data,1)
        total_correct1 += (pd1 == tg1).sum().item()
        total_correct2 += (pd2 == tg2).sum().item()
        total_correct3 += (pd3 == tg3).sum().item()
        total += tg1.size(0)
        
    acc1 = 100. * total_correct1 / total
    acc2 = 100. * total_correct2 / total
    acc3 = 100. * total_correct3 / total
    print("\nTest Epoch #%d Loss: %.4f Orig_Acc: %.2f%% animal_Acc: %.2f%% random_Acc: %.2f%%" %(epoch+1,
                                                                                                  test_loss/(i+1),
                                                                                                  acc1,acc2,acc3))
        
    return test_loss/(i+1), acc1, acc2, acc3

In [9]:
num_epochs = 20
# Per-epoch history: mean train loss, mean test loss, and the test accuracies
# (Orig, animal, random heads) returned by `tester`.
train_loss, test_loss, test_acc = [], [], []

for epoch in range(num_epochs):
    epoch_train_loss = train(net, train_loader, optimizer, criterion, epoch, device)
    epoch_test_loss, *epoch_test_acc = tester(net, test_loader, optimizer, criterion, epoch, device)

    train_loss.append(epoch_train_loss)
    test_loss.append(epoch_test_loss)
    test_acc.append(tuple(epoch_test_acc))

100%|██████████| 500/500 [00:36<00:00, 13.79it/s]
  2%|▏         | 2/100 [00:00<00:06, 14.54it/s]

Epoch: [1]  loss: [3.65] Orig_Acc [9.97] animal_Acc [59.87] random_Acc [60.00] 


100%|██████████| 100/100 [00:06<00:00, 16.12it/s]
  0%|          | 2/500 [00:00<00:35, 14.21it/s]


Test Epoch #1 Loss: 3.6406 Orig_Acc: 10.49% animal_Acc: 60.00% random_Acc: 60.00%


100%|██████████| 500/500 [00:34<00:00, 14.29it/s]
  2%|▏         | 2/100 [00:00<00:05, 16.78it/s]

Epoch: [2]  loss: [3.63] Orig_Acc [11.40] animal_Acc [60.16] random_Acc [60.00] 


100%|██████████| 100/100 [00:06<00:00, 16.61it/s]
  0%|          | 2/500 [00:00<00:37, 13.30it/s]


Test Epoch #2 Loss: 3.5974 Orig_Acc: 14.51% animal_Acc: 61.83% random_Acc: 60.00%


100%|██████████| 500/500 [00:34<00:00, 14.50it/s]
  2%|▏         | 2/100 [00:00<00:05, 17.32it/s]

Epoch: [3]  loss: [3.50] Orig_Acc [16.26] animal_Acc [72.20] random_Acc [59.95] 


100%|██████████| 100/100 [00:06<00:00, 16.62it/s]
  0%|          | 2/500 [00:00<00:33, 15.04it/s]


Test Epoch #3 Loss: 3.3345 Orig_Acc: 18.38% animal_Acc: 77.88% random_Acc: 59.98%


100%|██████████| 500/500 [00:35<00:00, 14.28it/s]
  2%|▏         | 2/100 [00:00<00:05, 16.99it/s]

Epoch: [4]  loss: [3.25] Orig_Acc [19.88] animal_Acc [79.65] random_Acc [60.06] 


100%|██████████| 100/100 [00:06<00:00, 16.37it/s]
  0%|          | 2/500 [00:00<00:33, 15.04it/s]


Test Epoch #4 Loss: 3.1083 Orig_Acc: 22.03% animal_Acc: 82.63% random_Acc: 61.36%


100%|██████████| 500/500 [00:34<00:00, 14.38it/s]
  2%|▏         | 2/100 [00:00<00:06, 15.91it/s]

Epoch: [5]  loss: [3.06] Orig_Acc [22.00] animal_Acc [83.33] random_Acc [60.57] 


100%|██████████| 100/100 [00:06<00:00, 15.15it/s]
  0%|          | 2/500 [00:00<00:32, 15.15it/s]


Test Epoch #5 Loss: 2.9630 Orig_Acc: 24.82% animal_Acc: 84.60% random_Acc: 61.61%


100%|██████████| 500/500 [00:34<00:00, 14.62it/s]
  2%|▏         | 2/100 [00:00<00:05, 16.45it/s]

Epoch: [6]  loss: [2.95] Orig_Acc [24.13] animal_Acc [85.21] random_Acc [60.78] 


100%|██████████| 100/100 [00:06<00:00, 15.83it/s]
  0%|          | 2/500 [00:00<00:43, 11.42it/s]


Test Epoch #6 Loss: 2.8499 Orig_Acc: 28.27% animal_Acc: 87.04% random_Acc: 62.40%


100%|██████████| 500/500 [00:34<00:00, 14.31it/s]
  2%|▏         | 2/100 [00:00<00:06, 14.37it/s]

Epoch: [7]  loss: [2.89] Orig_Acc [26.15] animal_Acc [86.40] random_Acc [61.14] 


100%|██████████| 100/100 [00:06<00:00, 16.59it/s]
  0%|          | 2/500 [00:00<00:37, 13.11it/s]


Test Epoch #7 Loss: 2.8008 Orig_Acc: 29.35% animal_Acc: 87.75% random_Acc: 62.87%


100%|██████████| 500/500 [00:34<00:00, 14.45it/s]
  2%|▏         | 2/100 [00:00<00:06, 14.26it/s]

Epoch: [8]  loss: [2.85] Orig_Acc [27.56] animal_Acc [86.74] random_Acc [61.56] 


100%|██████████| 100/100 [00:06<00:00, 16.41it/s]
  0%|          | 2/500 [00:00<00:43, 11.44it/s]


Test Epoch #8 Loss: 2.7904 Orig_Acc: 30.62% animal_Acc: 86.06% random_Acc: 63.44%


100%|██████████| 500/500 [00:34<00:00, 14.51it/s]
  2%|▏         | 2/100 [00:00<00:06, 15.92it/s]

Epoch: [9]  loss: [2.79] Orig_Acc [29.54] animal_Acc [87.15] random_Acc [62.31] 


100%|██████████| 100/100 [00:06<00:00, 15.32it/s]
  0%|          | 2/500 [00:00<00:36, 13.74it/s]


Test Epoch #9 Loss: 2.7002 Orig_Acc: 32.62% animal_Acc: 88.23% random_Acc: 62.93%


100%|██████████| 500/500 [00:34<00:00, 14.36it/s]
  2%|▏         | 2/100 [00:00<00:05, 17.86it/s]

Epoch: [10]  loss: [2.76] Orig_Acc [31.01] animal_Acc [87.26] random_Acc [62.53] 


100%|██████████| 100/100 [00:06<00:00, 16.26it/s]
  0%|          | 2/500 [00:00<00:39, 12.68it/s]


Test Epoch #10 Loss: 2.6560 Orig_Acc: 34.27% animal_Acc: 88.35% random_Acc: 65.29%


100%|██████████| 500/500 [00:34<00:00, 14.40it/s]
  2%|▏         | 2/100 [00:00<00:05, 17.72it/s]

Epoch: [11]  loss: [2.71] Orig_Acc [32.78] animal_Acc [87.61] random_Acc [63.52] 


100%|██████████| 100/100 [00:06<00:00, 16.29it/s]
  0%|          | 2/500 [00:00<00:39, 12.63it/s]


Test Epoch #11 Loss: 2.7431 Orig_Acc: 33.71% animal_Acc: 85.12% random_Acc: 65.81%


100%|██████████| 500/500 [00:34<00:00, 14.45it/s]
  2%|▏         | 2/100 [00:00<00:06, 15.82it/s]

Epoch: [12]  loss: [2.68] Orig_Acc [34.30] animal_Acc [87.97] random_Acc [64.23] 


100%|██████████| 100/100 [00:06<00:00, 15.44it/s]
  0%|          | 2/500 [00:00<00:40, 12.43it/s]


Test Epoch #12 Loss: 2.5785 Orig_Acc: 36.75% animal_Acc: 88.70% random_Acc: 66.23%


100%|██████████| 500/500 [00:35<00:00, 14.17it/s]
  2%|▏         | 2/100 [00:00<00:06, 16.11it/s]

Epoch: [13]  loss: [2.64] Orig_Acc [35.51] animal_Acc [88.09] random_Acc [65.02] 


100%|██████████| 100/100 [00:06<00:00, 15.98it/s]
  0%|          | 2/500 [00:00<00:37, 13.34it/s]


Test Epoch #13 Loss: 2.5440 Orig_Acc: 38.48% animal_Acc: 88.90% random_Acc: 66.90%


100%|██████████| 500/500 [00:34<00:00, 14.38it/s]
  2%|▏         | 2/100 [00:00<00:06, 14.64it/s]

Epoch: [14]  loss: [2.62] Orig_Acc [36.15] animal_Acc [88.38] random_Acc [65.33] 


100%|██████████| 100/100 [00:06<00:00, 16.19it/s]
  0%|          | 2/500 [00:00<00:33, 14.81it/s]


Test Epoch #14 Loss: 2.5392 Orig_Acc: 38.50% animal_Acc: 88.32% random_Acc: 67.17%


100%|██████████| 500/500 [00:34<00:00, 14.49it/s]
  2%|▏         | 2/100 [00:00<00:06, 16.07it/s]

Epoch: [15]  loss: [2.59] Orig_Acc [37.09] animal_Acc [88.35] random_Acc [65.69] 


100%|██████████| 100/100 [00:06<00:00, 16.28it/s]
  0%|          | 2/500 [00:00<00:35, 14.20it/s]


Test Epoch #15 Loss: 2.4967 Orig_Acc: 39.69% animal_Acc: 89.00% random_Acc: 67.78%


100%|██████████| 500/500 [00:34<00:00, 14.70it/s]
  2%|▏         | 2/100 [00:00<00:05, 17.68it/s]

Epoch: [16]  loss: [2.58] Orig_Acc [37.53] animal_Acc [88.37] random_Acc [66.13] 


100%|██████████| 100/100 [00:05<00:00, 16.92it/s]
  0%|          | 2/500 [00:00<00:37, 13.25it/s]


Test Epoch #16 Loss: 2.4914 Orig_Acc: 40.39% animal_Acc: 88.85% random_Acc: 68.01%


100%|██████████| 500/500 [00:34<00:00, 14.38it/s]
  2%|▏         | 2/100 [00:00<00:05, 16.49it/s]

Epoch: [17]  loss: [2.56] Orig_Acc [37.93] animal_Acc [88.30] random_Acc [66.45] 


100%|██████████| 100/100 [00:06<00:00, 16.01it/s]
  0%|          | 2/500 [00:00<00:41, 11.88it/s]


Test Epoch #17 Loss: 2.4563 Orig_Acc: 40.94% animal_Acc: 89.37% random_Acc: 69.01%


100%|██████████| 500/500 [00:34<00:00, 14.46it/s]
  2%|▏         | 2/100 [00:00<00:05, 17.48it/s]

Epoch: [18]  loss: [2.54] Orig_Acc [38.78] animal_Acc [88.64] random_Acc [66.85] 


100%|██████████| 100/100 [00:05<00:00, 17.04it/s]
  0%|          | 2/500 [00:00<00:33, 14.95it/s]


Test Epoch #18 Loss: 2.4433 Orig_Acc: 40.72% animal_Acc: 89.51% random_Acc: 68.82%


100%|██████████| 500/500 [00:34<00:00, 14.59it/s]
  2%|▏         | 2/100 [00:00<00:05, 17.78it/s]

Epoch: [19]  loss: [2.52] Orig_Acc [39.30] animal_Acc [88.66] random_Acc [67.09] 


100%|██████████| 100/100 [00:06<00:00, 16.32it/s]
  0%|          | 2/500 [00:00<00:38, 12.93it/s]


Test Epoch #19 Loss: 2.4188 Orig_Acc: 41.71% animal_Acc: 89.43% random_Acc: 70.15%


100%|██████████| 500/500 [00:34<00:00, 14.43it/s]
  2%|▏         | 2/100 [00:00<00:06, 14.46it/s]

Epoch: [20]  loss: [2.50] Orig_Acc [39.90] animal_Acc [88.75] random_Acc [67.35] 


100%|██████████| 100/100 [00:05<00:00, 16.85it/s]


Test Epoch #20 Loss: 2.4025 Orig_Acc: 42.32% animal_Acc: 89.41% random_Acc: 69.88%



