In [1]:
import sys
sys.path.append(sys.path[0]+'/../')

import torch
import torch.nn as nn
import time
from tqdm import tqdm
from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from dataloaders.minc_dataloader import MINCDataset, MINCDataLoader
from utils.loss_resnet import LossRN

from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from torchvision.models import swin_t, Swin_T_Weights, swin_s, Swin_S_Weights, swin_b, Swin_B_Weights, swin_v2_s, Swin_V2_S_Weights


In [2]:
minc_path = '/home/ahmed/workspace/notebook/matrec/datasets/minc'
labels_path = '/home/ahmed/workspace/notebook/matrec/datasets/minc/train.txt'
labels_path_t = '/home/ahmed/workspace/notebook/matrec/datasets/minc/test.txt'
BATCH_SIZE = 8
TRAIN_ITER = 10000
TEST_ITER = 400
LOAD = True
start = 50000
lr = 0.00001
size = 256

train_dataloader = MINCDataLoader(minc_path, labels_path, batch_size=BATCH_SIZE, size=size, f=0.16)

dataset_test = MINCDataset(minc_path, labels_path_t, size=(size, size))
test_loader = DataLoader(dataset=dataset_test, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False, shuffle=True)

In [3]:
checkpoint = "../weights/enet_minc.pth"


model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
model.classifier = nn.Sequential(
        nn.Dropout(p=0.4, inplace=True),
        nn.Linear(in_features=1280, out_features=23, bias=True)
    )
model.load_state_dict(torch.load(checkpoint), strict=False)
model = model.train()
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss = LossRN()

In [4]:
"""from torchvision.models import resnet18

checkpoint = "../weights/resnet18_minc.pth"

model = resnet18()
model.fc = nn.Linear(model.fc.in_features, 23)
model.load_state_dict(torch.load(checkpoint), strict=False)


model = model.train()
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss = LossRN()"""

'from torchvision.models import resnet18\n\ncheckpoint = "../weights/resnet18_minc.pth"\n\nmodel = resnet18()\nmodel.fc = nn.Linear(model.fc.in_features, 23)\nmodel.load_state_dict(torch.load(checkpoint), strict=False)\n\n\nmodel = model.train()\nmodel = model.cuda()\noptimizer = torch.optim.Adam(model.parameters(), lr=lr)\nloss = LossRN()'

In [5]:
#add augmentation
#rescaling [0, 1] and normalization (mean, std)
#effichientnet v2 s stopped at itr=50000 bs=8
for epc in range(1,3):
    r = tqdm(range(start, len(train_dataloader)), leave=False, desc="Epoch {} starting from iteration {}: ".format(epc, start), 
                                                                     total=len(train_dataloader)-start)    
    
    for idx in r:
        x, y = train_dataloader[idx]
        y_pred = model(x.cuda())
        lf = loss.compute(y_pred, y)
        lf.backward()
        optimizer.step()
        optimizer.zero_grad()
        r.set_postfix(loss=lf.item())
        
        if idx % TRAIN_ITER == 0 and idx != start:
            #save checkpoint
            torch.save(model.state_dict(), checkpoint)
            #test loss run on val
            with torch.no_grad():
                ac = 0
                for i in range(TEST_ITER):
                    x, y = next(iter(test_loader))
                    y_pred = model(x.cuda())
                    ac = ac + loss.accuracy(y_pred, y)
                ac = ac/TEST_ITER
                print('Test N ' + str(epc) + '-' + str(idx//TRAIN_ITER) + ' : ' + str(float(ac)) + '\n')
    start = 0
    torch.save(model.state_dict(), checkpoint)

Epoch 1 starting from iteration 150000:  35%|███▌      | 10001/28251 [31:32<143:03:30, 28.22s/it, loss=0.335]

Test N 1-16 : 77.0625



Epoch 1 starting from iteration 150000:  71%|███████   | 20002/28251 [1:04:07<45:04:39, 19.67s/it, loss=0.403]

Test N 1-17 : 78.21875



Epoch 2 starting from iteration 0:   6%|▌         | 10002/178251 [32:44<925:08:08, 19.79s/it, loss=0.2]       

Test N 2-1 : 78.875



Epoch 2 starting from iteration 0:  11%|█         | 20001/178251 [1:05:46<1251:57:00, 28.48s/it, loss=0.631]

Test N 2-2 : 79.375



Epoch 2 starting from iteration 0:  17%|█▋        | 30001/178251 [1:38:08<1123:43:41, 27.29s/it, loss=0.245]

Test N 2-3 : 76.75



Epoch 2 starting from iteration 0:  22%|██▏       | 40002/178251 [2:10:15<761:45:37, 19.84s/it, loss=0.615] 

Test N 2-4 : 78.0625



Epoch 2 starting from iteration 0:  28%|██▊       | 50002/178251 [2:42:29<686:55:15, 19.28s/it, loss=0.456] 

Test N 2-5 : 79.5



                                                                                                           

KeyboardInterrupt: 