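# ResNet Train-Validation-Test

Fine-tunes an ImageNet-pretrained ResNet-50 to classify hernia-surgery video frames into 7 merged surgical phases, then evaluates the saved checkpoint on held-out test videos.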

In [1]:
import torch
from torchvision.transforms import (
    Compose, ToTensor, Resize, Normalize, RandomResizedCrop, RandomHorizontalFlip, RandomVerticalFlip
)
from torchvision.models import resnet50

from data import HerniaDataset, PhaseMapper
from utils import ResnetTrainer, ResnetEvaluator

import random

%load_ext autoreload
%autoreload 2

## Dataset

In [2]:
# Root directory that holds the per-video frame folders.
root = '../../surgery_hernia_train_test/'

# Names of the 120 videos, numbered 0001..0120.
videos = ['RALIHR_surgeon01_fps01_{:04}'.format(i + 1) for i in range(120)]

# Shuffle with a dedicated, seeded generator so the video-level
# train/valid/test split derived from this ordering is the same on every
# "Restart Kernel -> Run All". With an unseeded shuffle, re-running the
# notebook and then loading a previously saved checkpoint could evaluate on
# videos that checkpoint was trained on.
# A private Random instance is used so the global `random` state is untouched.
SPLIT_SEED = 0
random.Random(SPLIT_SEED).shuffle(videos)

In [3]:
# CSV describing how the raw phase labels are merged into 7 classes.
labels_csv = '../configs/all_labels_hernia_merged_7.csv'
mapper = PhaseMapper(labels_csv)

# Display the merged label table.
mapper.get_merged_labels()

Unnamed: 0,labels
0,mesh placement
1,out of body
2,peritoneal closure
3,peritoneal scoring
4,preperioneal dissection
5,reduction of hernia
6,transitionary idle


In [4]:
# Network input resolution and per-channel normalisation statistics
# (presumably computed on this dataset rather than ImageNet — TODO confirm).
input_shape = (224, 224)
mean = [0.41757566, 0.26098573, 0.25888634]
std = [0.21938758, 0.1983, 0.19342837]

# Training pipeline: random square crop covering 80-100% of the frame area,
# resized to `input_shape`, then a random horizontal flip applied on the tensor.
train_steps = [
    RandomResizedCrop(size=input_shape, scale=(0.8, 1.0), ratio=(1.0, 1.0)),
    ToTensor(),
    RandomHorizontalFlip(),
    Normalize(mean, std),
]
train_transform = Compose(train_steps)

# Validation/test pipeline: plain resize, no augmentation.
valid_steps = [
    Resize(input_shape),
    ToTensor(),
    Normalize(mean, std),
]
valid_transform = Compose(valid_steps)

# Split at video level: first 50 videos -> train, next 20 -> validation,
# remaining 50 -> test.
train_videos = videos[:50]
valid_videos = videos[50:70]
test_videos = videos[70:120]

train_set = HerniaDataset(root, train_videos, transforms=train_transform, class_map=mapper)
valid_set = HerniaDataset(root, valid_videos, transforms=valid_transform, class_map=mapper)
test_set = HerniaDataset(root, test_videos, transforms=valid_transform, class_map=mapper)

# Size of each split, in (train, valid, test) order.
tuple(len(split) for split in (train_set, valid_set, test_set))

(143243, 71814, 147205)

## Model

In [5]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [6]:
# Start from pretrained ResNet-50 weights.
model = resnet50(pretrained=True)

# Swap the final fully connected layer for a fresh 7-way head, one output per
# merged phase label. `in_features` is read from the existing head (2048 for
# ResNet-50).
num_classes = 7
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)

## Training

In [7]:
# Hand the model and the selected device to the project's training helper.
trainer = ResnetTrainer(model, device)

In [8]:
# Hyperparameters and data-loading options for this run, kept together so
# they are easy to tune.
train_config = {
    'num_epochs': 10,
    'batch_size': 64,
    'learning_rate': 1e-5,
    'run_name': 'resnet50-p7-v120-b64-lr1em5-a',
    'num_workers': 1,
    'prefetch_factor': 4,
}

# Train on the training split, validating on the validation split; the
# returned history is displayed as the cell output.
hist = trainer.train(train_set, valid_set, **train_config)
hist

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mharrypotter1501[0m ([33meezklab[0m). Use [1m`wandb login --relogin`[0m to force relogin


Running resnet50-p7-v120-b64-lr1em5-a
Datasets: num_train = 143243, num_validation = 71814
Main metric: valid_f1_weighted
Epoch 1/10: validating 99.91095280498664% (batch 1123/1123)



Epoch 1/10: train_loss 0.48985117, train_accuracy 0.82643480, train_f1_macro 0.77004582, train_f1_weighted 0.82167837, valid_loss 0.54256631, valid_accuracy 0.79801988, valid_f1_macro 0.77796166, valid_f1_weighted 0.79430813, _timestamp 1658960373.00000000, _runtime 4822.00000000
Epoch 2/10: training 45.20% (batch 1013/2239)

wandb: Network error (ReadTimeout), entering retry loop.


Epoch 2/10: train_loss 0.26736269, train_accuracy 0.90498663, train_f1_macro 0.89064892, train_f1_weighted 0.90406878, valid_loss 0.54457758, valid_accuracy 0.79459437, valid_f1_macro 0.78174065, valid_f1_weighted 0.79417457, _timestamp 1658966194.00000000, _runtime 10643.00000000
Epoch 3/10: train_loss 0.20397792, train_accuracy 0.92847120, train_f1_macro 0.91839572, train_f1_weighted 0.92799093, valid_loss 0.56635960, valid_accuracy 0.78598880, valid_f1_macro 0.77425233, valid_f1_weighted 0.78691694, _timestamp 1658971761.00000000, _runtime 16210.00000000
Epoch 4/10: train_loss 0.17234931, train_accuracy 0.93879631, train_f1_macro 0.93062004, train_f1_weighted 0.93846389, valid_loss 0.54841482, valid_accuracy 0.79391205, valid_f1_macro 0.77910861, valid_f1_weighted 0.79259199, _timestamp 1658977258.00000000, _runtime 21707.00000000
Epoch 5/10: validating 99.91095280498664% (batch 1123/1123)



Epoch 5/10: train_loss 0.14533585, train_accuracy 0.94848614, train_f1_macro 0.94097746, train_f1_weighted 0.94826077, valid_loss 0.54751464, valid_accuracy 0.79704514, valid_f1_macro 0.77718210, valid_f1_weighted 0.79583665, _timestamp 1658982690.00000000, _runtime 27139.00000000
Epoch 6/10: train_loss 0.13017060, train_accuracy 0.95395936, train_f1_macro 0.94762054, train_f1_weighted 0.95376992, valid_loss 0.64429603, valid_accuracy 0.76918150, valid_f1_macro 0.76534038, valid_f1_weighted 0.77106643, _timestamp 1658988121.00000000, _runtime 32570.00000000
Epoch 7/10: train_loss 0.11309956, train_accuracy 0.96022842, train_f1_macro 0.95453944, train_f1_weighted 0.96008869, valid_loss 0.56689893, valid_accuracy 0.79070933, valid_f1_macro 0.77795330, valid_f1_weighted 0.79165793, _timestamp 1658993380.00000000, _runtime 37829.00000000
Epoch 8/10: train_loss 0.10277900, train_accuracy 0.96383767, train_f1_macro 0.95912967, train_f1_weighted 0.96373139, valid_loss 0.61072616, valid_accura

{'train_loss': [0.4898511746090509,
  0.2673626902108482,
  0.2039779186674717,
  0.17234930998216041,
  0.14533584687381304,
  0.13017059831500266,
  0.11309956078421343,
  0.10277900118721868,
  0.0927916445214902,
  0.08597531209355815],
 'train_accuracy': [0.8264347996062635,
  0.9049866311093736,
  0.9284711992907158,
  0.9387963111635472,
  0.9484861389387265,
  0.9539593557800381,
  0.9602284230293976,
  0.963837674441334,
  0.9670629629371069,
  0.9697786279259719],
 'train_f1_macro': [0.7700458222294447,
  0.8906489160095781,
  0.9183957212294686,
  0.9306200393699527,
  0.940977462025602,
  0.9476205440725707,
  0.9545394423454917,
  0.959129666140549,
  0.9622892223862911,
  0.9654612360464035],
 'train_f1_weighted': [0.8216783664208452,
  0.9040687759173089,
  0.9279909307424321,
  0.938463887224978,
  0.9482607726167331,
  0.9537699177789725,
  0.9600886893131895,
  0.9637313933083438,
  0.9669730539190137,
  0.9697162973351067],
 'valid_loss': [0.5425663081753272,
  0.544

## Evaluation

In [10]:
# Reload the checkpoint saved for this run before evaluating.
# `map_location=device` remaps the stored tensors onto the device selected
# earlier in the notebook, so a checkpoint written from a CUDA session can
# still be loaded on a CPU-only machine (or a different GPU index) instead of
# failing during deserialization.
checkpoint_path = './model/resnet50-p7-v120-b64-lr1em5-a.pt'
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [11]:
# Class names come from the 'labels' column of the merged label table.
class_names = mapper.get_merged_labels()['labels']
evaluator = ResnetEvaluator(model, device, class_names)

In [12]:
# Data-loading options for the test pass.
eval_loader_options = {'num_workers': 1, 'prefetch_factor': 128}

# Evaluate on the held-out test split; `res` holds the summary metrics and
# `report` the per-class breakdown shown in the next cell.
res, report = evaluator.evaluate(test_set, **eval_loader_options)
res

Testing ResNet
Datasets: num_test = 147205
Testing 100.00% (batch 147201/147205)

{'time': 0.012879398337913869,
 'accuracy': 0.7696681498590401,
 'f1_macro': 0.7608279192870857,
 'f1_weighted': 0.7679778534893521}

In [13]:
# Per-class precision/recall/F1 table from the evaluation above.
# `report` appears to be pre-formatted text, so print() is used rather than
# rich display to keep the column alignment.
print(report)

                         precision    recall  f1-score   support

         mesh placement       0.75      0.81      0.78     34092
            out of body       0.94      0.99      0.96       906
     peritoneal closure       0.87      0.86      0.87     36029
     peritoneal scoring       0.88      0.63      0.74      6745
preperioneal dissection       0.62      0.76      0.68     18409
    reduction of hernia       0.84      0.79      0.82     36725
     transitionary idle       0.55      0.43      0.48     14299

               accuracy                           0.77    147205
              macro avg       0.78      0.75      0.76    147205
           weighted avg       0.77      0.77      0.77    147205

