In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights
import torchvision.transforms.v2 as transforms
import torch.nn as nn
from kvasir_capsule_dataset import get_dataloader
from trainer import train_kc_model
import gc
from custom_transforms import GaussianBlur, RandomChoiceExtended
import os

with_gpu = torch.cuda.is_available()

if with_gpu:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
print('We are now using %s.' % device)

train_path = r"C:\Users\JadHa\Desktop\Uni\DLMB\DLMI-Project\kvasir-capsule-labeled-images\dataset_train.csv"
val_path = r"C:\Users\JadHa\Desktop\Uni\DLMB\DLMI-Project\kvasir-capsule-labeled-images\dataset_test.csv"
dataset_path = r"C:\Users\JadHa\Desktop\Uni\DLMB\DLMI-Project\kvasir-capsule-labeled-images\labelled_images"



We are now using cuda.


# Training with Contrastive Learning (SimCLR)

In [11]:
from simclr import SimCLR
from model import ResNetSimCLR
transforms_list = [transforms.Compose([transforms.RandomResizedCrop(size=96),transforms.Resize(size=[336, 336])]),
                   transforms.RandomRotation(degrees=360),
                   GaussianBlur(kernel_size=3),
                   transforms.ColorJitter(0.1, 0.1, 0.003, 0.003)]

transform = RandomChoiceExtended(transforms_list, min_transforms=0, max_transforms=3)

dataset_len = 39201

train_loader = get_dataloader(csv_path=train_path, 
                              dataset_path=dataset_path, 
                              batch_size=64, 
                              use_preloaded=True, 
                              start_idx=0, 
                              end_idx=int(2**13)+4096, 
                              shuffle=True, 
                              transforms=transform, 
                              drop_data_till_balanced=False, 
                              simclr=True, 
                              pin_memory=True, 
                              drop_last=True)

In [43]:
model = ResNetSimCLR(base_model="resnet18", out_dim=128)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                        last_epoch=-1)

args = {
    "device" : torch.device("cuda"),
    "batch_size" : 64,
    "n_views" : 2,
    "temperature" : 0.07,
    "fp16_precision" : True,
    "epochs" : 200,
    "log_every_n_steps" : 100,
    "arch" : "resnet18",
    "save_every_n_epochs" : 5
}

#  It’s a no-op if the 'gpu_index' argument is a negative integer or None.
with torch.cuda.device(0):
    torch.cuda.empty_cache()
    gc.collect()
    simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
    simclr.train(train_loader)

100%|██████████| 192/192 [07:06<00:00,  2.22s/it]


Epoch: 0	Loss: 2.2531490325927734


100%|██████████| 192/192 [07:17<00:00,  2.28s/it]


Epoch: 1	Loss: 1.7051080465316772


100%|██████████| 192/192 [07:16<00:00,  2.27s/it]


Epoch: 2	Loss: 1.1215505599975586


100%|██████████| 192/192 [07:03<00:00,  2.21s/it]


Epoch: 3	Loss: 1.972456693649292


100%|██████████| 192/192 [06:56<00:00,  2.17s/it]


Epoch: 4	Loss: 0.9122457504272461


100%|██████████| 192/192 [07:03<00:00,  2.21s/it]


Epoch: 5	Loss: 0.7935145497322083


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 6	Loss: 1.0135186910629272


100%|██████████| 192/192 [06:57<00:00,  2.17s/it]


Epoch: 7	Loss: 0.9064102172851562


100%|██████████| 192/192 [07:01<00:00,  2.19s/it]


Epoch: 8	Loss: 0.5957052707672119


100%|██████████| 192/192 [07:11<00:00,  2.25s/it]


Epoch: 9	Loss: 1.4095003604888916


100%|██████████| 192/192 [07:01<00:00,  2.19s/it]


Epoch: 10	Loss: 0.5840535163879395


100%|██████████| 192/192 [07:01<00:00,  2.20s/it]


Epoch: 11	Loss: 0.5167120099067688


100%|██████████| 192/192 [07:23<00:00,  2.31s/it]


Epoch: 12	Loss: 0.778656005859375


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 13	Loss: 0.4586251974105835


100%|██████████| 192/192 [07:02<00:00,  2.20s/it]


Epoch: 14	Loss: 1.5423729419708252


100%|██████████| 192/192 [06:55<00:00,  2.16s/it]


Epoch: 15	Loss: 0.7708524465560913


100%|██████████| 192/192 [06:59<00:00,  2.18s/it]


Epoch: 16	Loss: 0.6036813855171204


100%|██████████| 192/192 [06:59<00:00,  2.19s/it]


Epoch: 17	Loss: 0.8450571298599243


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 18	Loss: 0.5954111218452454


100%|██████████| 192/192 [07:01<00:00,  2.19s/it]


Epoch: 19	Loss: 1.7102864980697632


100%|██████████| 192/192 [07:03<00:00,  2.21s/it]


Epoch: 20	Loss: 0.885287880897522


100%|██████████| 192/192 [07:01<00:00,  2.19s/it]


Epoch: 21	Loss: 2.7145557403564453


100%|██████████| 192/192 [07:18<00:00,  2.28s/it]


Epoch: 22	Loss: 0.44392797350883484


100%|██████████| 192/192 [07:02<00:00,  2.20s/it]


Epoch: 23	Loss: 0.29819566011428833


100%|██████████| 192/192 [06:53<00:00,  2.15s/it]


Epoch: 24	Loss: 0.71708083152771


100%|██████████| 192/192 [06:55<00:00,  2.16s/it]


Epoch: 25	Loss: 0.7494648694992065


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 26	Loss: 1.4733991622924805


100%|██████████| 192/192 [06:59<00:00,  2.18s/it]


Epoch: 27	Loss: 0.6344439387321472


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 28	Loss: 0.9277408719062805


100%|██████████| 192/192 [06:55<00:00,  2.16s/it]


Epoch: 29	Loss: 0.3448299467563629


100%|██████████| 192/192 [06:53<00:00,  2.15s/it]


Epoch: 30	Loss: 0.8180556297302246


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 31	Loss: 0.33386528491973877


100%|██████████| 192/192 [07:15<00:00,  2.27s/it]


Epoch: 32	Loss: 1.0371906757354736


100%|██████████| 192/192 [07:07<00:00,  2.23s/it]


Epoch: 33	Loss: 0.8706024289131165


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 34	Loss: 0.707263708114624


100%|██████████| 192/192 [06:59<00:00,  2.19s/it]


Epoch: 35	Loss: 0.49317753314971924


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 36	Loss: 0.7868250608444214


100%|██████████| 192/192 [06:55<00:00,  2.17s/it]


Epoch: 37	Loss: 1.3775726556777954


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 38	Loss: 1.232825517654419


100%|██████████| 192/192 [07:01<00:00,  2.20s/it]


Epoch: 39	Loss: 0.3615981936454773


100%|██████████| 192/192 [06:57<00:00,  2.17s/it]


Epoch: 40	Loss: 0.3669615685939789


100%|██████████| 192/192 [07:01<00:00,  2.20s/it]


Epoch: 41	Loss: 1.2365336418151855


100%|██████████| 192/192 [07:13<00:00,  2.26s/it]


Epoch: 42	Loss: 0.4501655697822571


100%|██████████| 192/192 [07:03<00:00,  2.21s/it]


Epoch: 43	Loss: 1.322818398475647


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 44	Loss: 0.478716641664505


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 45	Loss: 0.8350521922111511


100%|██████████| 192/192 [06:54<00:00,  2.16s/it]


Epoch: 46	Loss: 0.37174293398857117


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 47	Loss: 0.3052789568901062


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 48	Loss: 0.4849785566329956


100%|██████████| 192/192 [07:06<00:00,  2.22s/it]


Epoch: 49	Loss: 0.41147351264953613


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 50	Loss: 0.5000397562980652


100%|██████████| 192/192 [07:07<00:00,  2.23s/it]


Epoch: 51	Loss: 1.0626264810562134


100%|██████████| 192/192 [07:18<00:00,  2.29s/it]


Epoch: 52	Loss: 1.0317132472991943


100%|██████████| 192/192 [06:56<00:00,  2.17s/it]


Epoch: 53	Loss: 0.2687304615974426


100%|██████████| 192/192 [06:59<00:00,  2.19s/it]


Epoch: 54	Loss: 0.1234329342842102


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 55	Loss: 0.2005760222673416


100%|██████████| 192/192 [07:01<00:00,  2.19s/it]


Epoch: 56	Loss: 0.6462100744247437


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 57	Loss: 1.1432130336761475


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 58	Loss: 0.30191561579704285


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 59	Loss: 1.184007167816162


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 60	Loss: 0.3576620817184448


100%|██████████| 192/192 [07:06<00:00,  2.22s/it]


Epoch: 61	Loss: 0.31897106766700745


100%|██████████| 192/192 [07:12<00:00,  2.25s/it]


Epoch: 62	Loss: 1.1469839811325073


100%|██████████| 192/192 [06:57<00:00,  2.18s/it]


Epoch: 63	Loss: 0.6231482028961182


100%|██████████| 192/192 [07:02<00:00,  2.20s/it]


Epoch: 64	Loss: 0.6482008695602417


100%|██████████| 192/192 [07:04<00:00,  2.21s/it]


Epoch: 65	Loss: 0.6344423294067383


100%|██████████| 192/192 [06:56<00:00,  2.17s/it]


Epoch: 66	Loss: 0.41362881660461426


100%|██████████| 192/192 [06:59<00:00,  2.19s/it]


Epoch: 67	Loss: 0.4081663489341736


100%|██████████| 192/192 [07:02<00:00,  2.20s/it]


Epoch: 68	Loss: 0.4168906509876251


100%|██████████| 192/192 [06:48<00:00,  2.13s/it]


Epoch: 69	Loss: 0.6517927646636963


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 70	Loss: 0.9837464690208435


100%|██████████| 192/192 [06:47<00:00,  2.12s/it]


Epoch: 71	Loss: 0.28088605403900146


100%|██████████| 192/192 [06:57<00:00,  2.17s/it]


Epoch: 72	Loss: 0.49288368225097656


100%|██████████| 192/192 [06:55<00:00,  2.16s/it]


Epoch: 73	Loss: 0.16639167070388794


100%|██████████| 192/192 [06:40<00:00,  2.09s/it]


Epoch: 74	Loss: 0.47368571162223816


100%|██████████| 192/192 [06:43<00:00,  2.10s/it]


Epoch: 75	Loss: 0.37473064661026


100%|██████████| 192/192 [06:48<00:00,  2.13s/it]


Epoch: 76	Loss: 0.15805402398109436


100%|██████████| 192/192 [07:01<00:00,  2.20s/it]


Epoch: 77	Loss: 0.4914311170578003


100%|██████████| 192/192 [06:57<00:00,  2.17s/it]


Epoch: 78	Loss: 1.2932891845703125


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 79	Loss: 0.28116363286972046


100%|██████████| 192/192 [06:55<00:00,  2.16s/it]


Epoch: 80	Loss: 1.7395963668823242


100%|██████████| 192/192 [06:57<00:00,  2.17s/it]


Epoch: 81	Loss: 0.26812246441841125


100%|██████████| 192/192 [06:50<00:00,  2.14s/it]


Epoch: 82	Loss: 0.5526418089866638


100%|██████████| 192/192 [06:53<00:00,  2.15s/it]


Epoch: 83	Loss: 2.0587191581726074


100%|██████████| 192/192 [07:03<00:00,  2.21s/it]


Epoch: 84	Loss: 1.7523753643035889


100%|██████████| 192/192 [06:49<00:00,  2.13s/it]


Epoch: 85	Loss: 0.32045382261276245


100%|██████████| 192/192 [06:44<00:00,  2.11s/it]


Epoch: 86	Loss: 0.39758551120758057


100%|██████████| 192/192 [06:43<00:00,  2.10s/it]


Epoch: 87	Loss: 0.3469770848751068


100%|██████████| 192/192 [06:47<00:00,  2.12s/it]


Epoch: 88	Loss: 1.7217386960983276


100%|██████████| 192/192 [06:42<00:00,  2.10s/it]


Epoch: 89	Loss: 9.14508056640625


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 90	Loss: 0.4481651782989502


100%|██████████| 192/192 [06:40<00:00,  2.09s/it]


Epoch: 91	Loss: 1.1797230243682861


100%|██████████| 192/192 [06:42<00:00,  2.09s/it]


Epoch: 92	Loss: 1.6793129444122314


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 93	Loss: 0.2606317400932312


100%|██████████| 192/192 [07:01<00:00,  2.20s/it]


Epoch: 94	Loss: 0.22142867743968964


100%|██████████| 192/192 [06:54<00:00,  2.16s/it]


Epoch: 95	Loss: 0.20613542199134827


100%|██████████| 192/192 [06:42<00:00,  2.09s/it]


Epoch: 96	Loss: 0.2448015809059143


100%|██████████| 192/192 [06:42<00:00,  2.10s/it]


Epoch: 97	Loss: 0.16851909458637238


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 98	Loss: 0.2897515892982483


100%|██████████| 192/192 [06:42<00:00,  2.10s/it]


Epoch: 99	Loss: 0.3424556255340576


100%|██████████| 192/192 [06:39<00:00,  2.08s/it]


Epoch: 100	Loss: 0.2465226650238037


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 101	Loss: 0.7001491785049438


100%|██████████| 192/192 [06:43<00:00,  2.10s/it]


Epoch: 102	Loss: 0.18870383501052856


100%|██████████| 192/192 [06:44<00:00,  2.11s/it]


Epoch: 103	Loss: 0.3873102068901062


100%|██████████| 192/192 [06:48<00:00,  2.13s/it]


Epoch: 104	Loss: 0.31436777114868164


100%|██████████| 192/192 [07:06<00:00,  2.22s/it]


Epoch: 105	Loss: 0.5171740055084229


100%|██████████| 192/192 [06:38<00:00,  2.08s/it]


Epoch: 106	Loss: 0.33017081022262573


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 107	Loss: 3.7625908851623535


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 108	Loss: 0.21616104245185852


100%|██████████| 192/192 [06:37<00:00,  2.07s/it]


Epoch: 109	Loss: 0.2347840964794159


100%|██████████| 192/192 [06:43<00:00,  2.10s/it]


Epoch: 110	Loss: 0.3277970850467682


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 111	Loss: 0.23358601331710815


100%|██████████| 192/192 [06:44<00:00,  2.11s/it]


Epoch: 112	Loss: 2.5940425395965576


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 113	Loss: 0.17243659496307373


100%|██████████| 192/192 [06:50<00:00,  2.14s/it]


Epoch: 114	Loss: 0.21876731514930725


100%|██████████| 192/192 [06:59<00:00,  2.19s/it]


Epoch: 115	Loss: 1.6632837057113647


100%|██████████| 192/192 [06:59<00:00,  2.18s/it]


Epoch: 116	Loss: 0.48003578186035156


100%|██████████| 192/192 [06:43<00:00,  2.10s/it]


Epoch: 117	Loss: 0.1141618937253952


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 118	Loss: 0.2628929615020752


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 119	Loss: 0.053460247814655304


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 120	Loss: 0.16094531118869781


100%|██████████| 192/192 [06:49<00:00,  2.13s/it]


Epoch: 121	Loss: 0.24290606379508972


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 122	Loss: 1.1804393529891968


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 123	Loss: 0.7236275672912598


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 124	Loss: 0.23613902926445007


100%|██████████| 192/192 [06:39<00:00,  2.08s/it]


Epoch: 125	Loss: 0.47387516498565674


100%|██████████| 192/192 [07:03<00:00,  2.20s/it]


Epoch: 126	Loss: 0.24783849716186523


100%|██████████| 192/192 [06:49<00:00,  2.13s/it]


Epoch: 127	Loss: 0.20300257205963135


100%|██████████| 192/192 [06:44<00:00,  2.11s/it]


Epoch: 128	Loss: 1.0653536319732666


100%|██████████| 192/192 [06:43<00:00,  2.10s/it]


Epoch: 129	Loss: 3.8481898307800293


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 130	Loss: 1.0757293701171875


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 131	Loss: 0.5935533046722412


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 132	Loss: 0.15170511603355408


100%|██████████| 192/192 [06:47<00:00,  2.12s/it]


Epoch: 133	Loss: 0.36051803827285767


100%|██████████| 192/192 [06:48<00:00,  2.13s/it]


Epoch: 134	Loss: 0.15562091767787933


100%|██████████| 192/192 [06:47<00:00,  2.12s/it]


Epoch: 135	Loss: 0.24977153539657593


100%|██████████| 192/192 [06:49<00:00,  2.13s/it]


Epoch: 136	Loss: 0.21761320531368256


100%|██████████| 192/192 [06:54<00:00,  2.16s/it]


Epoch: 137	Loss: 0.33152493834495544


100%|██████████| 192/192 [06:52<00:00,  2.15s/it]


Epoch: 138	Loss: 0.1452987939119339


100%|██████████| 192/192 [06:47<00:00,  2.12s/it]


Epoch: 139	Loss: 0.2454790323972702


100%|██████████| 192/192 [06:47<00:00,  2.12s/it]


Epoch: 140	Loss: 0.1890789270401001


100%|██████████| 192/192 [06:53<00:00,  2.15s/it]


Epoch: 141	Loss: 0.2432636320590973


100%|██████████| 192/192 [06:43<00:00,  2.10s/it]


Epoch: 142	Loss: 0.5444148778915405


100%|██████████| 192/192 [06:52<00:00,  2.15s/it]


Epoch: 143	Loss: 0.3142277002334595


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 144	Loss: 0.29236719012260437


100%|██████████| 192/192 [06:42<00:00,  2.09s/it]


Epoch: 145	Loss: 0.1549953818321228


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 146	Loss: 1.0272748470306396


100%|██████████| 192/192 [06:46<00:00,  2.12s/it]


Epoch: 147	Loss: 0.9591196775436401


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 148	Loss: 0.44842880964279175


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 149	Loss: 0.2753964066505432


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 150	Loss: 0.6304421424865723


100%|██████████| 192/192 [06:42<00:00,  2.10s/it]


Epoch: 151	Loss: 0.5514675378799438


100%|██████████| 192/192 [06:40<00:00,  2.08s/it]


Epoch: 152	Loss: 1.6879370212554932


100%|██████████| 192/192 [07:03<00:00,  2.21s/it]


Epoch: 153	Loss: 0.3659796118736267


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 154	Loss: 0.44822394847869873


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 155	Loss: 0.36173015832901


100%|██████████| 192/192 [06:47<00:00,  2.12s/it]


Epoch: 156	Loss: 0.29125046730041504


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 157	Loss: 0.16686101257801056


100%|██████████| 192/192 [06:50<00:00,  2.14s/it]


Epoch: 158	Loss: 0.23924186825752258


100%|██████████| 192/192 [06:59<00:00,  2.18s/it]


Epoch: 159	Loss: 0.15904201567173004


100%|██████████| 192/192 [06:50<00:00,  2.14s/it]


Epoch: 160	Loss: 0.12698431313037872


100%|██████████| 192/192 [06:41<00:00,  2.09s/it]


Epoch: 161	Loss: 0.9029024839401245


100%|██████████| 192/192 [06:45<00:00,  2.11s/it]


Epoch: 162	Loss: 1.0250048637390137


100%|██████████| 192/192 [07:03<00:00,  2.21s/it]


Epoch: 163	Loss: 0.9215497374534607


100%|██████████| 192/192 [07:32<00:00,  2.36s/it]


Epoch: 164	Loss: 0.25962549448013306


100%|██████████| 192/192 [07:09<00:00,  2.23s/it]


Epoch: 165	Loss: 0.38533496856689453


100%|██████████| 192/192 [07:01<00:00,  2.19s/it]


Epoch: 166	Loss: 0.9928542375564575


100%|██████████| 192/192 [07:18<00:00,  2.28s/it]


Epoch: 167	Loss: 0.4902390241622925


100%|██████████| 192/192 [07:13<00:00,  2.26s/it]


Epoch: 168	Loss: 0.5440274477005005


100%|██████████| 192/192 [07:44<00:00,  2.42s/it]


Epoch: 169	Loss: 0.19910025596618652


100%|██████████| 192/192 [07:41<00:00,  2.40s/it]


Epoch: 170	Loss: 1.0421059131622314


100%|██████████| 192/192 [07:20<00:00,  2.29s/it]


Epoch: 171	Loss: 3.136303424835205


100%|██████████| 192/192 [07:17<00:00,  2.28s/it]


Epoch: 172	Loss: 2.54754638671875


100%|██████████| 192/192 [07:33<00:00,  2.36s/it]


Epoch: 173	Loss: 0.22153633832931519


100%|██████████| 192/192 [07:08<00:00,  2.23s/it]


Epoch: 174	Loss: 0.4861600399017334


100%|██████████| 192/192 [07:25<00:00,  2.32s/it]


Epoch: 175	Loss: 0.13101354241371155


100%|██████████| 192/192 [07:22<00:00,  2.31s/it]


Epoch: 176	Loss: 1.67152738571167


100%|██████████| 192/192 [07:13<00:00,  2.26s/it]


Epoch: 177	Loss: 0.5731610059738159


100%|██████████| 192/192 [07:07<00:00,  2.22s/it]


Epoch: 178	Loss: 0.5454331040382385


100%|██████████| 192/192 [07:05<00:00,  2.22s/it]


Epoch: 179	Loss: 0.1548595279455185


100%|██████████| 192/192 [07:13<00:00,  2.26s/it]


Epoch: 180	Loss: 0.5858038663864136


100%|██████████| 192/192 [07:19<00:00,  2.29s/it]


Epoch: 181	Loss: 0.3487468659877777


100%|██████████| 192/192 [07:16<00:00,  2.28s/it]


Epoch: 182	Loss: 0.47178342938423157


100%|██████████| 192/192 [06:56<00:00,  2.17s/it]


Epoch: 183	Loss: 0.5465964078903198


100%|██████████| 192/192 [07:03<00:00,  2.20s/it]


Epoch: 184	Loss: 0.12119133770465851


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 185	Loss: 0.15868666768074036


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 186	Loss: 0.5785863399505615


100%|██████████| 192/192 [06:59<00:00,  2.19s/it]


Epoch: 187	Loss: 0.22854235768318176


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 188	Loss: 0.18879181146621704


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 189	Loss: 0.22569970786571503


100%|██████████| 192/192 [06:57<00:00,  2.17s/it]


Epoch: 190	Loss: 0.5195077657699585


100%|██████████| 192/192 [07:02<00:00,  2.20s/it]


Epoch: 191	Loss: 0.24797110259532928


100%|██████████| 192/192 [07:04<00:00,  2.21s/it]


Epoch: 192	Loss: 0.3544115126132965


100%|██████████| 192/192 [06:58<00:00,  2.18s/it]


Epoch: 193	Loss: 0.3259449899196625


100%|██████████| 192/192 [06:59<00:00,  2.19s/it]


Epoch: 194	Loss: 0.19518086314201355


100%|██████████| 192/192 [06:59<00:00,  2.18s/it]


Epoch: 195	Loss: 1.6723631620407104


100%|██████████| 192/192 [06:59<00:00,  2.18s/it]


Epoch: 196	Loss: 1.686869502067566


100%|██████████| 192/192 [07:00<00:00,  2.19s/it]


Epoch: 197	Loss: 0.20794028043746948


100%|██████████| 192/192 [06:56<00:00,  2.17s/it]


Epoch: 198	Loss: 0.527816116809845


100%|██████████| 192/192 [06:56<00:00,  2.17s/it]


Epoch: 199	Loss: 0.3368372619152069


# Finetuning pretrained model

In [4]:
from model import ResNetSimCLR

dataset_len = 39201
train_loader = get_dataloader(csv_path=train_path, 
                              dataset_path=dataset_path, 
                              batch_size=128, 
                              use_preloaded=True, 
                              start_idx=0, 
                              end_idx=dataset_len-1, 
                              shuffle=True, 
                              transforms=None, 
                              drop_data_till_balanced=True, 
                              simclr=False)

val_loader = get_dataloader(csv_path=val_path, dataset_path=dataset_path, batch_size=128, shuffle=True, use_preloaded=True, start_idx=0, end_idx=8047, simclr=False)

In [18]:
chkpoint = torch.load(os.path.join(r"runs\Jan25_03-08-13_DESKTOP-K6APQCL", "checkpoint_0200.pt"))

model = ResNetSimCLR(base_model="resnet18", out_dim=128)
model.load_state_dict(chkpoint["state_dict"])

for param in model.parameters(): # Freeze all parameters
    param.requires_grad = False

model.backbone.fc = nn.Linear(512, 2) # Replace fc layer from SimCLR with a new one
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

scaler = torch.cuda.amp.GradScaler()
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1, 9]).to(device))

model_name = "ResNet_SimCLR"

In [None]:
chkpoint = torch.load(os.path.join("saved_models", "%s.pt"%model_name))
model.load_state_dict(chkpoint["model_state_dict"])
optimizer.load_state_dict(chkpoint["optimizer"])
scaler.load_state_dict(chkpoint["scaler"])

In [19]:
torch.cuda.empty_cache()
gc.collect()
train_kc_model(model, optimizer, criterion, train_loader, val_loader, scaler, model_name=model_name, epochs=40, device=device)

Training :
Epoch : 1, Loss : 1.821, Accuracy : 0.500, Precision : 0.500, Recall : 1.000
Validating :
Epoch : 1, Loss : 1.863, Accuracy : 0.004, Precision : 0.004, Recall : 1.000
Model is stored at folder:saved_models/ResNet_SimCLR.pt
Training :
Epoch : 2, Loss : 1.570, Accuracy : 0.500, Precision : 0.500, Recall : 1.000
Validating :
Epoch : 2, Loss : 1.933, Accuracy : 0.004, Precision : 0.004, Recall : 1.000
Model is stored at folder:saved_models/ResNet_SimCLR.pt
Training :
Epoch : 3, Loss : 1.194, Accuracy : 0.502, Precision : 0.501, Recall : 1.000
Validating :
Epoch : 3, Loss : 1.612, Accuracy : 0.009, Precision : 0.004, Recall : 1.000
Model is stored at folder:saved_models/ResNet_SimCLR.pt
Training :
Epoch : 4, Loss : 0.965, Accuracy : 0.535, Precision : 0.518, Recall : 1.000
Validating :
Epoch : 4, Loss : 1.056, Accuracy : 0.205, Precision : 0.005, Recall : 1.000
Model is stored at folder:saved_models/ResNet_SimCLR.pt
Training :
Epoch : 5, Loss : 0.849, Accuracy : 0.685, Precision 

# Training with normal supervised Learning

In [None]:
transforms_list = [transforms.Compose([transforms.RandomResizedCrop(size=96),transforms.Resize(size=[336, 336])]),
                   transforms.RandomRotation(degrees=360),
                   GaussianBlur(kernel_size=3),
                   transforms.ColorJitter(0.1, 0.1, 0.003, 0.003)]

transform = RandomChoiceExtended(transforms_list, min_transforms=0, max_transforms=3)

dataset_len = 39201

train_loader = get_dataloader(csv_path=train_path, dataset_path=dataset_path, batch_size=128, use_preloaded=True, start_idx=4096, end_idx=int(2**13)+4096, shuffle=True, transforms=transform, drop_data_till_balanced=False, simclr=True)
val_loader = get_dataloader(csv_path=val_path, dataset_path=dataset_path, batch_size=128, shuffle=True, use_preloaded=True, start_idx=0, end_idx=8047, simclr=False)

In [None]:
resnet = resnet18().to(device) # weights=ResNet18_Weights.IMAGENET1K_V1

resnet.fc = nn.Linear(512, 2).to(device)

optimizer = optim.AdamW(params=resnet.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()

criterion = nn.CrossEntropyLoss() # weight=torch.tensor([1, 20]).to(device)

In [None]:
torch.cuda.empty_cache()
gc.collect()
train_kc_model(resnet, optimizer, criterion, train_loader, val_loader, scaler, model_name="Resnet_AMP", epochs=40, device=device)



Training :
Epoch : 1, Loss : 0.017, Accuracy : 0.968, Precision : 0.657, Recall : 0.804
Validating :
Epoch : 1, Loss : 0.065, Accuracy : 0.949, Precision : 0.008, Recall : 0.088
Model is stored at folder:saved_models/Resnet_AMP.pt
Training :
Epoch : 2, Loss : 0.023, Accuracy : 0.989, Precision : 0.921, Recall : 0.851
Validating :
Epoch : 2, Loss : 0.215, Accuracy : 0.940, Precision : 0.007, Recall : 0.088
Model is stored at folder:saved_models/Resnet_AMP.pt
Training :
Epoch : 3, Loss : 0.004, Accuracy : 0.991, Precision : 0.942, Recall : 0.884
Validating :
Epoch : 3, Loss : 0.372, Accuracy : 0.941, Precision : 0.007, Recall : 0.088
Model is stored at folder:saved_models/Resnet_AMP.pt
Training :
Epoch : 4, Loss : 0.007, Accuracy : 0.993, Precision : 0.951, Recall : 0.917
Validating :
Epoch : 4, Loss : 0.794, Accuracy : 0.941, Precision : 0.007, Recall : 0.088
Model is stored at folder:saved_models/Resnet_AMP.pt
Training :
Epoch : 5, Loss : 0.008, Accuracy : 0.994, Precision : 0.954, Rec

KeyboardInterrupt: 

# Training an XGBoost Classifier

In [13]:
from kvasir_capsule_dataset import KC_Dataset_Features
dataset_len = 39201
train_dataset = KC_Dataset_Features(start_idx=0, end_idx=dataset_len, csv_path=train_path, dataset_path=dataset_path)
val_dataset = KC_Dataset_Features(start_idx=0, end_idx=8047, csv_path=val_path, dataset_path=dataset_path)

100%|██████████| 39201/39201 [2:01:32<00:00,  5.38it/s]  
100%|██████████| 8047/8047 [23:54<00:00,  5.61it/s]


In [22]:
import xgboost as xgb

X_train, y_train = (train_dataset.x, train_dataset.y)
X_test, y_test = (val_dataset.x, val_dataset.y)

# Use "hist" for constructing the trees, with early stopping enabled.
clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=2, scale_pos_weight=1/100)
# Fit the model, test sets are used for early stopping.
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
# Save model into JSON format.
clf.save_model("xgboost_3.json")

preds = clf.predict(X_test)

[0]	validation_0-logloss:0.58344
[1]	validation_0-logloss:0.40206
[2]	validation_0-logloss:0.29162
[3]	validation_0-logloss:0.25121
[4]	validation_0-logloss:0.22556
[5]	validation_0-logloss:0.20624
[6]	validation_0-logloss:0.19864
[7]	validation_0-logloss:0.19456
[8]	validation_0-logloss:0.19351
[9]	validation_0-logloss:0.19281
[10]	validation_0-logloss:0.19485
[11]	validation_0-logloss:0.19263
[12]	validation_0-logloss:0.18566
[13]	validation_0-logloss:0.18701
[14]	validation_0-logloss:0.18977


In [18]:
clf = xgb.XGBClassifier(tree_method="hist", early_stopping_rounds=2)
clf.load_model("clf.json")
preds = clf.predict(X_test)

In [23]:
from sklearn.metrics import classification_report
import numpy as np
print(classification_report(np.argmax(y_test, axis=1), np.argmax(preds, axis=1), target_names=["normal", "blood"]))
print(np.sum(np.logical_and(np.argmax(y_test, axis=1) == 1, np.argmax(preds, axis=1) == 1)) / np.sum(np.argmax(y_test, axis=1)))
print(np.sum(np.logical_and(np.argmax(y_test, axis=1) == 1, np.argmax(preds, axis=1) == 1)) / np.sum(np.argmax(preds, axis=1)))

              precision    recall  f1-score   support

      normal       1.00      0.97      0.98      8013
       blood       0.00      0.00      0.00        34

    accuracy                           0.97      8047
   macro avg       0.50      0.49      0.49      8047
weighted avg       0.99      0.97      0.98      8047

0.0
0.0
