In [1]:
import sys
sys.path.append(sys.path[0]+'/../')

import torch
import torch.nn as nn
import time
from tqdm import tqdm
from torchsummary import summary
from torch.utils.data import DataLoader

from dataloaders.minc_dataloader import MINCDataset, MINCDataLoader
from utils.loss_resnet import LossRN



In [2]:
minc_path = '/home/ahmed/workspace/notebook/matrec/datasets/minc'
labels_path = '/home/ahmed/workspace/notebook/matrec/datasets/minc/train.txt'
labels_path_t = '/home/ahmed/workspace/notebook/matrec/datasets/minc/test.txt'
BATCH_SIZE = 4
TRAIN_ITER = 10000
TEST_ITER = 400
start = 220000
lr = 0.0000005
size = 256

train_dataloader = MINCDataLoader(minc_path, labels_path, batch_size=BATCH_SIZE, size=size, f=0.16)
dataset_test = MINCDataset(minc_path, labels_path_t, size=(size, size))
test_loader = DataLoader(dataset=dataset_test, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False, shuffle=True)

In [3]:
from torchvision.models import googlenet, GoogLeNet_Weights

checkpoint = "../weights/googlenet_minc.pth"

model = googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 23, bias=True)
model.load_state_dict(torch.load(checkpoint), strict=False)

model = model.train()
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss = LossRN()

In [None]:
for epc in range(1,3):
    r = tqdm(range(start, len(train_dataloader)), leave=False, desc="Epoch {} starting from iteration {}: ".format(epc, start), 
                                                                     total=len(train_dataloader)-start)    
    
    for idx in r:
        x, y = train_dataloader[idx]
        y_pred = model(x.cuda())
        lf = loss.compute(y_pred, y)
        lf.backward()
        optimizer.step()
        optimizer.zero_grad()
        r.set_postfix(loss=lf.item())
        
        if idx % TRAIN_ITER == 0 and idx != start:
            #save checkpoint
            torch.save(model.state_dict(), checkpoint)
            #test loss run on val
            with torch.no_grad():
                ac = 0
                for i in range(TEST_ITER):
                    x, y = next(iter(test_loader))
                    y_pred = model(x.cuda())
                    ac = ac + loss.accuracy(y_pred, y)
                ac = ac/TEST_ITER
                print('Test N ' + str(epc) + '-' + str(idx//TRAIN_ITER) + ' : ' + str(float(ac)) + '\n')
        start = 0

Epoch 1 starting from iteration 300000:  18%|█▊        | 10003/56502 [11:51<60:03:10,  4.65s/it, loss=0.0495]

Test N 1-31 : 72.25



Epoch 1 starting from iteration 300000:  35%|███▌      | 20002/56502 [24:04<66:13:42,  6.53s/it, loss=0.114] 

Test N 1-32 : 72.9375



Epoch 1 starting from iteration 300000:  53%|█████▎    | 30003/56502 [36:28<30:44:17,  4.18s/it, loss=1.45]  

Test N 1-33 : 73.5



Epoch 1 starting from iteration 300000:  71%|███████   | 40003/56502 [48:45<23:01:31,  5.02s/it, loss=0.589]  

Test N 1-34 : 72.125



Epoch 1 starting from iteration 300000:  88%|████████▊ | 50003/56502 [1:01:09<8:26:26,  4.68s/it, loss=0.422] 

Test N 1-35 : 73.9375



Epoch 2 starting from iteration 0:   3%|▎         | 10003/356502 [12:25<397:26:41,  4.13s/it, loss=0.26]      

Test N 2-1 : 72.75



Epoch 2 starting from iteration 0:   6%|▌         | 20004/356502 [24:51<349:23:39,  3.74s/it, loss=0.0245]

Test N 2-2 : 72.625



Epoch 2 starting from iteration 0:   8%|▊         | 30003/356502 [37:26<391:41:23,  4.32s/it, loss=1.05]  

Test N 2-3 : 73.125



Epoch 2 starting from iteration 0:  11%|█         | 40003/356502 [49:35<441:34:40,  5.02s/it, loss=0.181] 

Test N 2-4 : 73.5



Epoch 2 starting from iteration 0:  14%|█▍        | 50004/356502 [1:01:25<303:04:59,  3.56s/it, loss=0.0976]

Test N 2-5 : 73.3125



Epoch 2 starting from iteration 0:  17%|█▋        | 60001/356502 [1:13:15<606:49:47,  7.37s/it, loss=0.083] 

Test N 2-6 : 74.25



Epoch 2 starting from iteration 0:  20%|█▉        | 70003/356502 [1:25:07<278:27:32,  3.50s/it, loss=2.05]  

Test N 2-7 : 74.5625



Epoch 2 starting from iteration 0:  22%|██▏       | 80003/356502 [1:37:40<334:32:26,  4.36s/it, loss=0.278] 

Test N 2-8 : 74.75



Epoch 2 starting from iteration 0:  25%|██▌       | 90003/356502 [1:50:21<340:33:16,  4.60s/it, loss=0.997] 

Test N 2-9 : 72.9375



Epoch 2 starting from iteration 0:  28%|██▊       | 100003/356502 [2:03:12<336:35:02,  4.72s/it, loss=0.371]

Test N 2-10 : 73.8125



Epoch 2 starting from iteration 0:  31%|███       | 110003/356502 [2:16:00<331:03:03,  4.83s/it, loss=0.131] 

Test N 2-11 : 73.6875



Epoch 2 starting from iteration 0:  34%|███▎      | 120003/356502 [2:28:43<316:36:34,  4.82s/it, loss=0.26]  

Test N 2-12 : 72.375



Epoch 2 starting from iteration 0:  36%|███▋      | 130003/356502 [2:41:30<298:40:29,  4.75s/it, loss=0.773] 

Test N 2-13 : 72.6875



Epoch 2 starting from iteration 0:  39%|███▉      | 140004/356502 [2:54:18<248:21:45,  4.13s/it, loss=0.211] 

Test N 2-14 : 73.875



Epoch 2 starting from iteration 0:  42%|████▏     | 150003/356502 [3:07:08<296:27:57,  5.17s/it, loss=0.756] 

Test N 2-15 : 73.0625



Epoch 2 starting from iteration 0:  45%|████▍     | 160003/356502 [3:19:50<269:40:30,  4.94s/it, loss=0.573]

Test N 2-16 : 73.1875



Epoch 2 starting from iteration 0:  48%|████▊     | 170004/356502 [3:32:54<224:03:03,  4.32s/it, loss=0.117] 

Test N 2-17 : 72.5625



Epoch 2 starting from iteration 0:  50%|█████     | 180001/356502 [3:45:56<339:36:40,  6.93s/it, loss=0.254] 

Test N 2-18 : 73.4375



Epoch 2 starting from iteration 0:  53%|█████▎    | 190002/356502 [3:59:06<252:34:41,  5.46s/it, loss=0.228] 

Test N 2-19 : 71.875



Epoch 2 starting from iteration 0:  56%|█████▌    | 200003/356502 [4:11:27<200:12:34,  4.61s/it, loss=0.953] 

Test N 2-20 : 72.6875



Epoch 2 starting from iteration 0:  59%|█████▉    | 210003/356502 [4:23:39<207:05:14,  5.09s/it, loss=0.449] 

Test N 2-21 : 74.375



Epoch 2 starting from iteration 0:  62%|██████▏   | 220004/356502 [4:35:38<162:57:30,  4.30s/it, loss=0.357] 

Test N 2-22 : 73.75



Epoch 2 starting from iteration 0:  62%|██████▏   | 221983/356502 [4:37:54<2:10:25, 17.19it/s, loss=0.142]  