In [1]:
import yaml
import os
from collections import OrderedDict
from typing import Sequence, Dict, Optional, Sequence
from tqdm import tqdm
import torch
from torch.optim import Adam, SGD, RMSprop
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils import tensorboard

from models.classification import Classification
from classificationdataset import ClassificationDataset
from utils import *

## Load Task Config

In [2]:
configfile = open('config.yml')
config = yaml.load(configfile, Loader=yaml.FullLoader)

In [3]:
task = config['task']
taskconfig = config['tasks'][task.lower()]

In [4]:
print(f'The task defined is {task} with config {config}')


The task defined is classification with config {'task': 'classification', 'datadir': '/media/ranjan/dl/smartzoo/data', 'numepochs': 100, 'finetuneepochs': 0, 'logdir': 'logs', 'optimizer': {'name': 'rmsprop', 'alpha': 0.9, 'initlr': 0.256, 'momentum': 0.9, 'weightdecay': '1e-5'}, 'lossfn': 'crossentropyloss', 'pretrained': 'eff', 'pretrainedpath': 'logs/ckpt/ckpt_0.pt', 'batchsize': 4, 'imagesize': 300, 'tasks': {'classification': {'extractor': 'efficientnet-b3', 'numclasses': 6, 'weights': [], 'labelmapping': {'Lesser Mousedeer': 0, 'Long Tailed Macaque': 1, 'Others': 2, 'Sambar Deer': 3, 'Spotted Whistling Duck': 4, 'Wild Pig': 5}}}}


## Load Model

In [5]:
models = {
    'classification': Classification(extractor=taskconfig['extractor'], numclasses=taskconfig['numclasses'])
}

Loaded pretrained weights for efficientnet-b3


In [6]:
model = models[task]

In [7]:
startepoch = -1
if config['pretrained']  != 'imagenet' and config['pretrainedpath'] !='':
    print(f"Loading the checkpoint from {config['pretrainedpath']}")
    checkpoint = torch.load(config['pretrainedpath'], map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    startepoch = checkpoint['epoch']

Loading the checkpoint from logs/ckpt/ckpt_0.pt


In [8]:
cudaavailable = torch.cuda.is_available()
device = torch.device('cuda') if cudaavailable else torch.device('cpu')
model.to(device)

Classification(
  (encoder): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 40, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
    )
    (_bn0): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          40, 40, kernel_size=(3, 3), stride=[1, 1], groups=40, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          40, 10, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          10, 40, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
     

## Define loss function

In [9]:
lossfns = {
    'crossentropyloss': torch.nn.CrossEntropyLoss(reduction='none').to(device)
}

In [10]:
lossfn = lossfns[config['lossfn']]

## Define Optimizer

In [11]:
optimizers = {
    'sgd': SGD(model.parameters(), lr=float(config['optimizer']['initlr']), 
               momentum=float(config['optimizer']['momentum']), 
               weight_decay=float(config['optimizer']['momentum'])),
    'rmsprop': RMSprop(model.parameters(), lr=float(config['optimizer']['initlr']), 
                       alpha=float(config['optimizer']['alpha']),
                       momentum=float(config['optimizer']['momentum']), 
                       weight_decay=float(config['optimizer']['weightdecay']))
}

In [12]:
optimizer = optimizers[config['optimizer']['name']]
scheduler = lr_scheduler.StepLR(
            optimizer=optimizer, step_size=1, gamma=0.97 ** (1 / 2.4))

## Load Data

In [13]:
traindataset = ClassificationDataset(datadir=config['datadir'], fold='train', 
                                     imagesize=config['imagesize'], 
                                    labelmapping=taskconfig['labelmapping'])
traindataloader = DataLoader(traindataset, 
                             shuffle=True, 
                             pin_memory=True, 
                             num_workers=6, 
                             batch_size=config['batchsize'])

In [14]:
valdataset = ClassificationDataset(datadir=config['datadir'], fold='valid', 
                                   imagesize=config['imagesize'], 
                                   labelmapping=taskconfig['labelmapping'])
valdataloader = DataLoader(valdataset, 
                             shuffle=False, 
                             pin_memory=True, 
                             num_workers=6, 
                             batch_size=config['batchsize'])

## Create Tensorboard writer to write logs

In [15]:
logdir = config['logdir']

In [16]:
writer = tensorboard.SummaryWriter(config['logdir'])

## Train and Validate

In [17]:
def setfinetune(model: torch.nn.Module, finetune: bool) -> None:
    extractor = taskconfig['extractor']
    if finetune:
        finallayer = model.encoder.fc
        assert isinstance(finallayer, torch.nn.Module)

        #first set required grad to false
        model.encoder.requires_grad_(False)
        for param in model.encoder.fc.parameters():
            param.requires_grad = True
            
    else:
        model.encoder.requires_grad_(True)
    
        
    

In [18]:
def accuracy(outputs: torch.Tensor, labels: torch.Tensor,
            top: Sequence[int] = (1,)) -> Dict[int, float]:
    with torch.no_grad():
        # preds and labels both have shape [N, k]
        _, preds = outputs.topk(k=max(top), dim=1, largest=True, sorted=True)
        labels = labels.view(-1, 1).expand_as(preds)

        corrects = preds.eq(labels).cumsum(dim=1) 
        corrects = corrects.sum(dim=0)  # shape [k]
        
            
        tops = {k: corrects[k - 1].item() for k in top}
        
    return tops

In [19]:
def runepoch(model:torch.nn.Module, 
             dataloader: torch.utils.data.DataLoader,
             finetune: bool,
             device: torch.device,
             lossfn: Optional[torch.nn.Module] = None,
             optimizer: Optional[torch.nn.Module] = None,
             top: Sequence[int] = (1,)) -> Dict[str, float]:
    
    #set to eval when optimizer is none and also when finetune is True
    model.train(optimizer is not None and not finetune)
    
    #define object to track loss
    losses = Averagemeter()
    accuraciestopk = {k: Averagemeter() for k in top}
    
    tqdmloader = tqdm(dataloader)
    with torch.set_grad_enabled(optimizer is not None):
        for batch in tqdmloader:
            images, labels, imagepath = batch

            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            batchsize = labels.size(0)

            description = []
            outputs = model(images)

            if lossfn is not None:
                loss = lossfn(outputs, labels)
                loss = loss.mean()
                losses.update(loss.item(), batchsize)
                description.append(f'Loss {losses.val:.4f} ({losses.avg:.4f})')

            #backward pass
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
            topcorrect = accuracy(outputs, labels, top=top)
            for k, acc in accuraciestopk.items():
                acc.update(topcorrect[k] * (100. / batchsize), n=batchsize)
                description.append(f'accuracy@{k} {acc.val:.3f} ({acc.avg:.3f})')

            tqdmloader.set_description(' '.join(description))
            
    metrics = {}
    metrics['loss'] = losses.avg
    for k, acc in accuraciestopk.items():
        metrics[f'accuracytop{k}'] = acc.avg
        
    return metrics

In [20]:
bestmetrics: Dict[str, float] = {}
top = (1, 3)

In [None]:
for epoch in range(int(config['numepochs']))[startepoch + 1:]:
    print(f'Epoch: {epoch + 1}')
    writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch)
    
    #only finetune the final layers till the epochs less than finetuneepochs later tune entire model
    finetuneepochs = config['finetuneepochs']
    finetune = finetuneepochs > epoch
    setfinetune(model, finetune)
    
    
    print('Training')
    trainmetrics = runepoch(model, traindataloader, finetune, device, lossfn, optimizer, top)
    trainmetrics = prefix_all_keys(trainmetrics, prefix='train/')
    
    
    print('Validation')
    valmetrics = runepoch(model, valdataloader, finetune, device, lossfn, top=top)
    valmetrics = prefix_all_keys(valmetrics, prefix='val/')
    
    
    scheduler.step()
    
    if valmetrics['val/accuracytop1'] > bestmetrics.get('val/accuracytop1', 0):
        filename = os.path.join(logdir, 'ckpt', f'ckpt_{epoch}.pt')
        print(f'New best model! Saving checkpoint to {filename} with accuracy {valmetrics["val/accuracytop1"]}')
        state = {
            'epoch': epoch,
            'model': getattr(model, 'module', model).state_dict(),
            'val/acc': valmetrics['val/accuracytop1'],
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, filename)
        bestmetrics.update(trainmetrics)
        bestmetrics.update(valmetrics)
        bestmetrics['epoch'] = epoch  

  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 2
Training


Loss 3890.6895 (9865.4337) accuracy@1 50.000 (20.453) accuracy@3 100.000 (55.611): 100%|██████████| 2362/2362 [14:40<00:00,  2.68it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 46853.0703 (41057.5951) accuracy@1 0.000 (23.908) accuracy@3 100.000 (66.321): 100%|██████████| 1014/1014 [02:16<00:00,  7.42it/s] 


New best model! Saving checkpoint to logs/ckpt/ckpt_1.pt with accuracy 23.90821613619541


  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 3
Training


Loss 0.0000 (8680.9631) accuracy@1 100.000 (20.443) accuracy@3 100.000 (56.108): 100%|██████████| 2362/2362 [18:21<00:00,  2.15it/s]   
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 64359.0156 (44153.7326) accuracy@1 0.000 (14.631) accuracy@3 0.000 (64.890): 100%|██████████| 1014/1014 [02:05<00:00,  8.11it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 4
Training


Loss 3350.8125 (9242.9634) accuracy@1 50.000 (19.966) accuracy@3 100.000 (56.362): 100%|██████████| 2362/2362 [17:42<00:00,  2.22it/s] 
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 14421.9688 (17890.1282) accuracy@1 0.000 (26.351) accuracy@3 100.000 (59.437): 100%|██████████| 1014/1014 [02:00<00:00,  8.40it/s] 


New best model! Saving checkpoint to logs/ckpt/ckpt_3.pt with accuracy 26.350851221317544


  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 5
Training


Loss 4644.3203 (9207.3424) accuracy@1 50.000 (19.966) accuracy@3 50.000 (55.632): 100%|██████████| 2362/2362 [17:41<00:00,  2.22it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 20233.6250 (38526.8273) accuracy@1 0.000 (23.908) accuracy@3 100.000 (56.773): 100%|██████████| 1014/1014 [02:00<00:00,  8.39it/s] 
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 6
Training


Loss 243.6562 (8523.3321) accuracy@1 50.000 (19.543) accuracy@3 100.000 (55.992): 100%|██████████| 2362/2362 [17:47<00:00,  2.21it/s] 
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 12332684.0000 (7003202.2745) accuracy@1 0.000 (14.631) accuracy@3 100.000 (54.602): 100%|██████████| 1014/1014 [02:02<00:00,  8.31it/s]
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 7
Training


Loss 10310.8418 (8914.2457) accuracy@1 0.000 (19.765) accuracy@3 50.000 (55.696): 100%|██████████| 2362/2362 [17:50<00:00,  2.21it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 11635.6680 (8600.9406) accuracy@1 0.000 (2.492) accuracy@3 0.000 (41.031): 100%|██████████| 1014/1014 [02:00<00:00,  8.44it/s] 
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 8
Training


Loss 15120.9141 (7722.3565) accuracy@1 0.000 (19.977) accuracy@3 0.000 (55.791): 100%|██████████| 2362/2362 [17:41<00:00,  2.22it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 547542.0000 (500373.8116) accuracy@1 0.000 (15.223) accuracy@3 0.000 (53.787): 100%|██████████| 1014/1014 [02:00<00:00,  8.41it/s]   
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 9
Training


Loss 4250.3281 (7982.3491) accuracy@1 50.000 (19.998) accuracy@3 50.000 (55.579): 100%|██████████| 2362/2362 [17:54<00:00,  2.20it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 0.0000 (38385.0362) accuracy@1 100.000 (16.062) accuracy@3 100.000 (44.905): 100%|██████████| 1014/1014 [02:03<00:00,  8.21it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 10
Training


Loss 3310.6689 (7303.0904) accuracy@1 0.000 (20.125) accuracy@3 100.000 (56.415): 100%|██████████| 2362/2362 [18:02<00:00,  2.18it/s] 
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 23957.6035 (11632.5557) accuracy@1 0.000 (23.908) accuracy@3 0.000 (66.815): 100%|██████████| 1014/1014 [02:04<00:00,  8.17it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 11
Training


Loss 2744.6855 (7261.8274) accuracy@1 50.000 (21.025) accuracy@3 100.000 (57.051): 100%|██████████| 2362/2362 [18:13<00:00,  2.16it/s]
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 7367.0781 (3360.2958) accuracy@1 0.000 (26.351) accuracy@3 0.000 (66.815): 100%|██████████| 1014/1014 [02:01<00:00,  8.33it/s]   
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 12
Training


Loss 12604.7168 (7126.3427) accuracy@1 0.000 (20.718) accuracy@3 0.000 (56.024): 100%|██████████| 2362/2362 [17:54<00:00,  2.20it/s]   
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 0.0000 (7064.9412) accuracy@1 100.000 (16.062) accuracy@3 100.000 (52.134): 100%|██████████| 1014/1014 [02:03<00:00,  8.22it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 13
Training


Loss 2794.5195 (7183.9159) accuracy@1 50.000 (20.273) accuracy@3 100.000 (55.791): 100%|██████████| 2362/2362 [17:45<00:00,  2.22it/s]
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 3913609.5000 (869964.8938) accuracy@1 0.000 (26.302) accuracy@3 0.000 (66.815): 100%|██████████| 1014/1014 [02:01<00:00,  8.33it/s]   
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 14
Training


Loss 4335.4951 (6464.0014) accuracy@1 0.000 (20.305) accuracy@3 100.000 (55.907): 100%|██████████| 2362/2362 [18:08<00:00,  2.17it/s] 
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 29961.0625 (42756.8753) accuracy@1 0.000 (15.643) accuracy@3 100.000 (51.962): 100%|██████████| 1014/1014 [02:05<00:00,  8.08it/s]
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 15
Training


Loss 5723.6992 (6165.5745) accuracy@1 0.000 (20.157) accuracy@3 100.000 (56.458): 100%|██████████| 2362/2362 [18:23<00:00,  2.14it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 5380.8047 (6352.1809) accuracy@1 0.000 (26.277) accuracy@3 100.000 (57.316): 100%|██████████| 1014/1014 [02:05<00:00,  8.05it/s] 
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 16
Training


Loss 4219.2285 (5847.6983) accuracy@1 50.000 (20.665) accuracy@3 50.000 (56.320): 100%|██████████| 2362/2362 [18:17<00:00,  2.15it/s] 
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 2633.4336 (5703.9366) accuracy@1 0.000 (14.631) accuracy@3 100.000 (57.192): 100%|██████████| 1014/1014 [02:04<00:00,  8.12it/s]
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 17
Training


Loss 5230.3281 (6092.9640) accuracy@1 0.000 (20.443) accuracy@3 0.000 (56.500): 100%|██████████| 2362/2362 [18:24<00:00,  2.14it/s]   
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 9351.9336 (4994.0784) accuracy@1 0.000 (18.727) accuracy@3 0.000 (62.719): 100%|██████████| 1014/1014 [02:03<00:00,  8.23it/s]   
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 18
Training


Loss 3170.4121 (6495.0597) accuracy@1 50.000 (19.659) accuracy@3 50.000 (55.251): 100%|██████████| 2362/2362 [18:25<00:00,  2.14it/s] 
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 24127.1562 (12626.4065) accuracy@1 0.000 (21.737) accuracy@3 0.000 (61.041): 100%|██████████| 1014/1014 [02:04<00:00,  8.18it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 19
Training


Loss 1849.3164 (5385.7910) accuracy@1 50.000 (19.574) accuracy@3 100.000 (55.134): 100%|██████████| 2362/2362 [18:24<00:00,  2.14it/s]
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 30715.7383 (28863.4490) accuracy@1 0.000 (5.872) accuracy@3 0.000 (45.398): 100%|██████████| 1014/1014 [02:09<00:00,  7.82it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 20
Training


Loss 9054.4512 (5365.3699) accuracy@1 0.000 (20.008) accuracy@3 50.000 (56.246): 100%|██████████| 2362/2362 [18:29<00:00,  2.13it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 64380.3320 (14433.4992) accuracy@1 0.000 (14.656) accuracy@3 0.000 (59.882): 100%|██████████| 1014/1014 [02:03<00:00,  8.20it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 21
Training


Loss 6809.3770 (4974.4065) accuracy@1 50.000 (19.617) accuracy@3 50.000 (56.257): 100%|██████████| 2362/2362 [18:26<00:00,  2.13it/s] 
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 13417.8750 (6557.7163) accuracy@1 0.000 (24.648) accuracy@3 0.000 (66.075): 100%|██████████| 1014/1014 [02:12<00:00,  7.63it/s]  
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 22
Training


Loss 11761.2363 (5558.8724) accuracy@1 0.000 (20.400) accuracy@3 0.000 (56.669): 100%|██████████| 2362/2362 [21:10<00:00,  1.86it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 4214.8242 (8434.2382) accuracy@1 0.000 (20.775) accuracy@3 100.000 (44.041): 100%|██████████| 1014/1014 [02:33<00:00,  6.60it/s] 
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 23
Training


Loss 5168.3750 (5274.2330) accuracy@1 0.000 (19.680) accuracy@3 50.000 (56.077): 100%|██████████| 2362/2362 [22:15<00:00,  1.77it/s]  
  0%|          | 0/1014 [00:00<?, ?it/s]

Validation


Loss 106563.0000 (251106.7240) accuracy@1 0.000 (22.872) accuracy@3 100.000 (55.144): 100%|██████████| 1014/1014 [02:34<00:00,  6.56it/s]
  0%|          | 0/2362 [00:00<?, ?it/s]

Epoch: 24
Training


Loss 4196.8931 (4777.8994) accuracy@1 0.000 (18.944) accuracy@3 75.000 (56.171):  37%|███▋      | 871/2362 [08:16<14:38,  1.70it/s]  