In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader,random_split,Dataset
import torch.optim as optim
from tqdm import tqdm
from wideresnet import *


In [2]:
# Root of the pre-downloaded CIFAR-10 data (download=False below); point this
# at a local copy when running elsewhere.
DATA_DIR = '/home/aminul/data1/'
# Seeds torch's default RNG (used for weight init and DataLoader shuffling) so
# runs are more repeatable; cuDNN kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

trainset = datasets.CIFAR10(root=DATA_DIR, train=True, download=False, transform=transforms.ToTensor())
testset = datasets.CIFAR10(root=DATA_DIR, train=False, download=False, transform=transforms.ToTensor())

labels_list = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
# Indices into labels_list of the non-animal classes: airplane, automobile, ship, truck.
non_animal = [0,1,8,9]
# Fall back to CPU so the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [3]:
class NewDataset(Dataset):
    """Wrap a labelled dataset and attach two derived labels to every sample.

    Each item of the wrapped dataset is read as ``(image, label, ...)`` with an
    integer class label. ``__getitem__`` returns ``(image, label1, label2, label3)``:

    * ``label1`` -- the original class label;
    * ``label2`` -- 0 if the label is one of the non-animal classes, else 1;
    * ``label3`` -- 0 if the label is greater than 5, else 1.

    Args:
        data: indexable dataset whose items have the image at index 0 and the
            class label at index 1.
        transform: optional callable applied to the image. Non-callable values
            are ignored, so existing calls that pass a class list in this
            position keep behaving as before.
        non_animal_classes: class indices mapped to ``label2 == 0``. When None,
            the notebook-level ``non_animal`` list is used (looked up on access).
    """

    def __init__(self, data, transform=None, non_animal_classes=None):
        self.data = data
        self.transform = transform
        self.non_animal_classes = non_animal_classes

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Index the wrapped dataset only once: torchvision datasets such as
        # CIFAR10 rebuild the image and re-run their transform on every access.
        sample = self.data[idx]
        image, label = sample[0], sample[1]
        if callable(self.transform):
            image = self.transform(image)

        classes = non_animal if self.non_animal_classes is None else self.non_animal_classes
        label2 = 0 if label in classes else 1
        label3 = 0 if label > 5 else 1
        return image, label, label2, label3

In [4]:
# Wrap the raw CIFAR-10 splits so each sample also carries the derived
# animal/non-animal label and the ``label > 5`` split label. The second
# positional argument of NewDataset is a transform, so nothing is passed there.
new_trainset = NewDataset(trainset)
new_testset = NewDataset(testset)

train_loader = DataLoader(new_trainset, batch_size=100, shuffle=True)
# Evaluation metrics do not depend on sample order; keep the test pass deterministic.
test_loader = DataLoader(new_testset, batch_size=100, shuffle=False)

In [6]:
def train(net,trainloader,optim,criterion,epoch,device):
    """Run one training epoch of a three-head network and print per-head accuracy.

    Args:
        net: model whose forward returns a 3-tuple of logits ``(op1, op2, op3)``,
            scored against ``tg1``, ``tg2`` and ``tg3`` respectively.
        trainloader: iterable yielding ``(inputs, tg1, tg2, tg3)`` batches.
        optim: optimizer over ``net``'s parameters (this name shadows the
            ``torch.optim`` module inside the function only).
        criterion: loss applied independently to each head.
        epoch: zero-based epoch index, used only in the printed summary.
        device: device the batch tensors are moved to.

    Returns:
        Mean over batches of the summed three-head loss.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    net.train()
    train_loss = 0.0
    total = 0
    correct = [0, 0, 0]
    num_batches = 0

    for inputs, tg1, tg2, tg3 in tqdm(trainloader):
        inputs = inputs.to(device)
        targets = (tg1.to(device), tg2.to(device), tg3.to(device))
        optim.zero_grad()

        op1, op2, op3 = net(inputs)
        outputs = (op1, op2, op3)
        loss1, loss2, loss3 = (criterion(o, t) for o, t in zip(outputs, targets))

        # Backpropagating the summed loss accumulates the same gradients as three
        # separate backward() calls, but traverses the graph once and needs no
        # retain_graph=True.
        (loss1 + loss2 + loss3).backward()
        optim.step()

        train_loss += loss1.item() + loss2.item() + loss3.item()
        for head, (output, target) in enumerate(zip(outputs, targets)):
            # Accuracy bookkeeping only; detach so it stays outside autograd.
            correct[head] += (output.detach().argmax(dim=1) == target).sum().item()
        total += targets[0].size(0)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches")

    acc1, acc2, acc3 = (c * 100 / total for c in correct)
    print("Epoch: [{}]  loss: [{:.2f}] Orig_Acc [{:.2f}] animal_Acc [{:.2f}] random_Acc [{:.2f}] ".format(
        epoch + 1, train_loss / num_batches, acc1, acc2, acc3))
    return train_loss / num_batches

In [7]:
def tester(net,testloader,optim,criterion,epoch,device):
    net.eval()
    test_loss,total,total_correct1,total_correct2,total_correct3 = 0,0,0,0,0
    
    for i,(inputs,tg1,tg2,tg3) in enumerate(tqdm(testloader)):
        
        inputs,tg1,tg2,tg3 = inputs.to(device), tg1.to(device), tg2.to(device), tg3.to(device)
        optim.zero_grad()
        
        op1,op2,op3 = net(inputs)
        loss1 = criterion(op1,tg1)
        loss2 = criterion(op2,tg2)
        loss3 = criterion(op3,tg3)
        
        test_loss += loss1.item() + loss2.item() + loss3.item()
        _,pd1 = torch.max(op1.data,1)
        _,pd2 = torch.max(op2.data,1)
        _,pd3 = torch.max(op3.data,1)
        total_correct1 += (pd1 == tg1).sum().item()
        total_correct2 += (pd2 == tg2).sum().item()
        total_correct3 += (pd3 == tg3).sum().item()
        total += tg1.size(0)
        
    acc1 = 100. * total_correct1 / total
    acc2 = 100. * total_correct2 / total
    acc3 = 100. * total_correct3 / total
    print("\nTest Epoch #%d Loss: %.4f Orig_Acc: %.2f%% animal_Acc: %.2f%% random_Acc: %.2f%%" %(epoch+1,
                                                                                                  test_loss/(i+1),
                                                                                                  acc1,acc2,acc3))
        
    return test_loss/(i+1), acc1, acc2, acc3

In [8]:
# Presumably one output size per head of the network: the 10 original classes,
# then the two binary labels (animal / non-animal, label > 5) produced by NewDataset.
total_classes = [10, 2, 2]
net = WideResNet(
    depth=28,
    widen_factor=10,
    num_classes=total_classes,
).to(device)

# Every head is trained with the same cross-entropy loss.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=5e-4,
)

In [9]:
num_epochs = 20
# Per-epoch history: mean training loss, mean test loss, and the test
# accuracies (orig, animal, random) returned by tester.
train_loss, test_loss = [], []
test_acc = []

for epoch in range(num_epochs):
    epoch_train_loss = train(net, train_loader, optimizer, criterion, epoch, device)
    epoch_test_loss, *epoch_test_acc = tester(net, test_loader, optimizer, criterion, epoch, device)

    train_loss.append(epoch_train_loss)
    test_loss.append(epoch_test_loss)
    test_acc.append(tuple(epoch_test_acc))

100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:15,  6.28it/s]

Epoch: [1]  loss: [2.71] Orig_Acc [34.07] animal_Acc [85.98] random_Acc [65.96] 


100%|██████████| 100/100 [00:13<00:00,  7.55it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #1 Loss: 2.3915 Orig_Acc: 41.36% animal_Acc: 90.61% random_Acc: 69.56%


100%|██████████| 500/500 [04:52<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.90it/s]

Epoch: [2]  loss: [2.09] Orig_Acc [49.90] animal_Acc [91.81] random_Acc [74.10] 


100%|██████████| 100/100 [00:13<00:00,  7.40it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #2 Loss: 2.1370 Orig_Acc: 49.43% animal_Acc: 89.74% random_Acc: 75.84%


100%|██████████| 500/500 [04:54<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:12,  7.85it/s]

Epoch: [3]  loss: [1.73] Orig_Acc [59.50] animal_Acc [93.32] random_Acc [79.91] 


100%|██████████| 100/100 [00:13<00:00,  7.47it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #3 Loss: 1.7864 Orig_Acc: 57.34% animal_Acc: 93.59% random_Acc: 78.93%


100%|██████████| 500/500 [04:53<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.84it/s]

Epoch: [4]  loss: [1.47] Orig_Acc [65.50] animal_Acc [94.45] random_Acc [83.89] 


100%|██████████| 100/100 [00:13<00:00,  7.66it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #4 Loss: 3.4077 Orig_Acc: 43.65% animal_Acc: 88.02% random_Acc: 60.06%


100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:12,  7.92it/s]

Epoch: [5]  loss: [1.27] Orig_Acc [70.04] animal_Acc [95.29] random_Acc [86.82] 


100%|██████████| 100/100 [00:13<00:00,  7.65it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #5 Loss: 1.6184 Orig_Acc: 62.21% animal_Acc: 94.23% random_Acc: 80.08%


100%|██████████| 500/500 [04:53<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.88it/s]

Epoch: [6]  loss: [1.09] Orig_Acc [74.06] animal_Acc [95.97] random_Acc [89.19] 


100%|██████████| 100/100 [00:12<00:00,  7.79it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #6 Loss: 1.2777 Orig_Acc: 70.44% animal_Acc: 95.32% random_Acc: 86.02%


100%|██████████| 500/500 [04:52<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.95it/s]

Epoch: [7]  loss: [0.92] Orig_Acc [77.81] animal_Acc [96.54] random_Acc [91.68] 


100%|██████████| 100/100 [00:13<00:00,  7.69it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #7 Loss: 1.4999 Orig_Acc: 69.02% animal_Acc: 93.25% random_Acc: 83.41%


100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:14,  6.81it/s]

Epoch: [8]  loss: [0.77] Orig_Acc [80.87] animal_Acc [97.11] random_Acc [93.49] 


100%|██████████| 100/100 [00:13<00:00,  7.39it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #8 Loss: 1.5193 Orig_Acc: 69.23% animal_Acc: 95.13% random_Acc: 81.99%


100%|██████████| 500/500 [04:52<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.90it/s]

Epoch: [9]  loss: [0.63] Orig_Acc [83.93] animal_Acc [97.55] random_Acc [95.42] 


100%|██████████| 100/100 [00:12<00:00,  7.85it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #9 Loss: 1.9477 Orig_Acc: 67.62% animal_Acc: 93.22% random_Acc: 80.18%


100%|██████████| 500/500 [04:53<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.87it/s]

Epoch: [10]  loss: [0.51] Orig_Acc [86.76] animal_Acc [98.19] random_Acc [96.98] 


100%|██████████| 100/100 [00:13<00:00,  7.57it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #10 Loss: 1.4635 Orig_Acc: 70.54% animal_Acc: 95.56% random_Acc: 85.96%


100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:12,  7.90it/s]

Epoch: [11]  loss: [0.41] Orig_Acc [89.12] animal_Acc [98.50] random_Acc [97.73] 


100%|██████████| 100/100 [00:13<00:00,  7.56it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #11 Loss: 1.5558 Orig_Acc: 70.80% animal_Acc: 95.16% random_Acc: 85.27%


100%|██████████| 500/500 [04:52<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.84it/s]

Epoch: [12]  loss: [0.33] Orig_Acc [91.27] animal_Acc [99.00] random_Acc [98.35] 


100%|██████████| 100/100 [00:13<00:00,  7.57it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #12 Loss: 1.3832 Orig_Acc: 74.23% animal_Acc: 95.64% random_Acc: 86.84%


100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:13,  7.18it/s]

Epoch: [13]  loss: [0.27] Orig_Acc [93.05] animal_Acc [99.21] random_Acc [98.64] 


100%|██████████| 100/100 [00:13<00:00,  7.51it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #13 Loss: 1.4643 Orig_Acc: 74.80% animal_Acc: 95.42% random_Acc: 86.74%


100%|██████████| 500/500 [04:54<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:12,  7.85it/s]

Epoch: [14]  loss: [0.20] Orig_Acc [94.74] animal_Acc [99.55] random_Acc [99.03] 


100%|██████████| 100/100 [00:13<00:00,  7.54it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #14 Loss: 2.4584 Orig_Acc: 68.37% animal_Acc: 95.62% random_Acc: 81.49%


100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:12,  7.84it/s]

Epoch: [15]  loss: [0.14] Orig_Acc [96.57] animal_Acc [99.69] random_Acc [99.43] 


100%|██████████| 100/100 [00:12<00:00,  7.85it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #15 Loss: 1.4317 Orig_Acc: 75.58% animal_Acc: 95.69% random_Acc: 87.38%


100%|██████████| 500/500 [04:52<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:13,  7.34it/s]

Epoch: [16]  loss: [0.09] Orig_Acc [98.15] animal_Acc [99.78] random_Acc [99.67] 


100%|██████████| 100/100 [00:13<00:00,  7.38it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #16 Loss: 1.4018 Orig_Acc: 77.22% animal_Acc: 96.22% random_Acc: 88.41%


100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:12,  7.84it/s]

Epoch: [17]  loss: [0.06] Orig_Acc [98.93] animal_Acc [99.92] random_Acc [99.79] 


100%|██████████| 100/100 [00:12<00:00,  7.73it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #17 Loss: 1.4052 Orig_Acc: 77.62% animal_Acc: 96.13% random_Acc: 88.96%


100%|██████████| 500/500 [04:53<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.87it/s]

Epoch: [18]  loss: [0.04] Orig_Acc [99.41] animal_Acc [99.94] random_Acc [99.84] 


100%|██████████| 100/100 [00:12<00:00,  7.73it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #18 Loss: 1.4343 Orig_Acc: 77.58% animal_Acc: 95.40% random_Acc: 89.10%


100%|██████████| 500/500 [04:53<00:00,  1.70it/s]
  1%|          | 1/100 [00:00<00:13,  7.20it/s]

Epoch: [19]  loss: [0.02] Orig_Acc [99.71] animal_Acc [99.97] random_Acc [99.92] 


100%|██████████| 100/100 [00:13<00:00,  7.30it/s]
  0%|          | 0/500 [00:00<?, ?it/s]


Test Epoch #19 Loss: 1.5017 Orig_Acc: 77.97% animal_Acc: 96.42% random_Acc: 86.56%


100%|██████████| 500/500 [04:53<00:00,  1.71it/s]
  1%|          | 1/100 [00:00<00:12,  7.89it/s]

Epoch: [20]  loss: [0.02] Orig_Acc [99.72] animal_Acc [99.95] random_Acc [99.91] 


100%|██████████| 100/100 [00:13<00:00,  7.65it/s]


Test Epoch #20 Loss: 1.4501 Orig_Acc: 78.12% animal_Acc: 96.52% random_Acc: 88.54%



