In [1]:
import warnings
# Install an 'ignore' filter for every warning category to keep cell output compact.
# NOTE(review): this also hides deprecation notices from torch and other libraries.
warnings.filterwarnings('ignore')

In [2]:
import yaml
import os
from collections import OrderedDict
from typing import Sequence, Dict, Optional, Sequence
from tqdm import tqdm
import torch
from torch.optim import Adam, SGD, RMSprop
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils import tensorboard

from nn.classification import Classification
from classificationdataset import ClassificationDataset
from utils import *

## Load Task Config

In [3]:
# Read the run configuration. The context manager closes the file handle as
# soon as parsing finishes; `configfile` stays bound (to the closed file).
with open('config.yml') as configfile:
    # NOTE(review): FullLoader is kept so parsing is unchanged; it is fine for a
    # local, trusted config file. Prefer yaml.safe_load for untrusted input.
    config = yaml.load(configfile, Loader=yaml.FullLoader)

In [4]:
# Name of the task to run, as selected in the config file.
task = config['task']
# Settings specific to that task; the lookup key is the lower-cased task name.
taskconfig = config['tasks'][task.lower()]

In [5]:
# Echo the selected task together with the entire parsed config.
print(f'The task defined is {task} with config {config}')


The task defined is classification with config {'task': 'classification', 'datadir': '/media/ranjan/dl/smartzoo/data', 'numepochs': 100, 'finetuneepochs': 1, 'logdir': 'logs', 'optimizer': {'name': 'rmsprop', 'alpha': 0.9, 'initlr': 0.256, 'momentum': 0.9, 'weightdecay': '1e-5'}, 'lossfn': 'crossentropyloss', 'pretrained': 'imagenet', 'pretrainedpath': '', 'batchsize': 1, 'imagesize': 300, 'tasks': {'classification': {'extractor': 'efficientnet-b3', 'numclasses': 6, 'weights': [], 'labelmapping': {'Lesser Mousedeer': 0, 'Long Tailed Macaque': 1, 'Others': 2, 'Sambar Deer': 3, 'Spotted Whistling Duck': 4, 'Wild Pig': 5}}}}


## Load Model

In [6]:
# Registry mapping a task name to its model instance. The model is constructed
# eagerly, while this dict literal is evaluated.
models = {
    'classification': Classification(extractor=taskconfig['extractor'], numclasses=taskconfig['numclasses'])
}

Loaded pretrained weights for efficientnet-b3


In [7]:
# Pick the model for the configured task. The key is lower-cased to match how
# `taskconfig` is looked up, so a mixed-case task name resolves consistently.
model = models[task.lower()]

In [8]:
# Index of the last completed epoch; -1 makes the training loop start at epoch 0.
startepoch = -1
# Resume only when the config asks for non-ImageNet weights and supplies a path.
if config['pretrained']  != 'imagenet' and config['pretrainedpath'] !='':
    print(f"Loading the checkpoint from {config['pretrainedpath']}")
    # map_location='cpu' places the loaded tensors on the CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(config['pretrainedpath'], map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    # NOTE(review): only the model weights and the epoch are restored here; any
    # optimizer state stored in the checkpoint is not loaded.
    startepoch = checkpoint['epoch']

In [9]:
# Prefer the GPU whenever CUDA is usable; otherwise stay on the CPU.
cudaavailable = torch.cuda.is_available()
device = torch.device('cuda' if cudaavailable else 'cpu')
# Module.to returns the module itself, so the cell still ends by displaying
# the model's repr.
model.to(device)

Classification(
  (encoder): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 40, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
    )
    (_bn0): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          40, 40, kernel_size=(3, 3), stride=[1, 1], groups=40, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): BatchNorm2d(40, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          40, 10, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          10, 40, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
     

## Define loss function

In [10]:
# Registry mapping the config's `lossfn` name to a loss module.
# reduction='none' keeps one loss value per sample; runepoch averages them
# with loss.mean() before the backward pass.
lossfns = {
    'crossentropyloss': torch.nn.CrossEntropyLoss(reduction='none').to(device)
}

In [11]:
# Pick the loss named in the config; an unknown name raises KeyError.
lossfn = lossfns[config['lossfn']]

## Define Optimizer

In [12]:
# Registry mapping the config's optimizer name to an optimizer instance.
# Both optimizers are built eagerly, so every key read below (including
# `alpha`) must exist in the config even when only one optimizer is used.
# float() is applied because the YAML loader returns some values, such as
# '1e-5', as strings.
optimizers = {
    'sgd': SGD(model.parameters(), lr=float(config['optimizer']['initlr']),
               momentum=float(config['optimizer']['momentum']),
               # The L2 penalty comes from `weightdecay`; reading `momentum`
               # here would apply a weight decay of 0.9 with the shown config.
               weight_decay=float(config['optimizer']['weightdecay'])),
    'rmsprop': RMSprop(model.parameters(), lr=float(config['optimizer']['initlr']),
                       alpha=float(config['optimizer']['alpha']),
                       momentum=float(config['optimizer']['momentum']),
                       weight_decay=float(config['optimizer']['weightdecay']))
}

In [13]:
# Select the optimizer named in the config; an unknown name raises KeyError.
optimizer = optimizers[config['optimizer']['name']]
# Multiply the learning rate by gamma on every scheduler.step() (step_size=1).
# gamma = 0.97 ** (1 / 2.4), i.e. the rate shrinks by a factor of 0.97 over
# every 2.4 steps.
scheduler = lr_scheduler.StepLR(
            optimizer=optimizer, step_size=1, gamma=0.97 ** (1 / 2.4))

## Load Data

In [14]:
# Training split, built with fold='train' from config['datadir'].
# NOTE(review): ClassificationDataset is project code; runepoch unpacks each
# batch it yields as (images, labels, imagepath).
traindataset = ClassificationDataset(datadir=config['datadir'], fold='train', 
                                     imagesize=config['imagesize'], 
                                    labelmapping=taskconfig['labelmapping'])
# Shuffled loader with six worker processes; pinned host memory is requested.
traindataloader = DataLoader(traindataset, 
                             shuffle=True, 
                             pin_memory=True, 
                             num_workers=6, 
                             batch_size=config['batchsize'])

In [15]:
# Validation split, built with fold='valid' and the same image size and label
# mapping as the training split.
valdataset = ClassificationDataset(datadir=config['datadir'], fold='valid', 
                                   imagesize=config['imagesize'], 
                                   labelmapping=taskconfig['labelmapping'])
# Unshuffled loader; the remaining loader settings match the training loader.
valdataloader = DataLoader(valdataset, 
                             shuffle=False, 
                             pin_memory=True, 
                             num_workers=6, 
                             batch_size=config['batchsize'])

## Create TensorBoard writer to write logs

In [16]:
# Root directory for run artifacts; the training loop saves checkpoints under
# <logdir>/ckpt.
logdir = config['logdir']

In [17]:
# TensorBoard event writer rooted at the configured log directory (the same
# config value that `logdir` holds).
writer = tensorboard.SummaryWriter(config['logdir'])

## Train and Validate

In [18]:
def setfinetune(model: torch.nn.Module, finetune: bool) -> None:
    """Freeze or unfreeze parameters for the current training phase.

    Args:
        model: Network exposing an ``encoder`` attribute. The encoder's final
            layer is read from ``_fc`` when the extractor name contains
            'efficientnet' and from ``fc`` otherwise.
        finetune: If True, only the encoder's final layer stays trainable and
            every other parameter of the model is frozen. If False, every
            parameter of the model is made trainable.

    Reads the extractor name from the notebook-level ``taskconfig``.
    """
    if not finetune:
        # Mirror of the freeze below, which is applied to the whole model:
        # unfreezing only ``model.encoder`` would leave any parameter living
        # outside the encoder frozen for the rest of training.
        model.requires_grad_(True)
        return

    extractor = taskconfig['extractor']
    finallayer = model.encoder._fc if 'efficientnet' in extractor else model.encoder.fc
    assert isinstance(finallayer, torch.nn.Module)

    # Freeze everything first, then re-enable gradients for the final layer only.
    model.requires_grad_(False)
    finallayer.requires_grad_(True)
    
        
    

In [19]:
def accuracy(outputs: torch.Tensor, labels: torch.Tensor,
            top: Sequence[int] = (1,)) -> Dict[int, float]:
    """Count, for each requested k, the samples whose label is in the top-k scores.

    Args:
        outputs: Scores with one row per sample; ranking is done along dim 1.
        labels: One class index per sample; reshaped internally to a column.
        top: Values of k to evaluate.

    Returns:
        Dict mapping each k in ``top`` to the number of samples (a count, not
        a percentage) whose label is among the k highest-scoring classes.
    """
    deepest = max(top)
    with torch.no_grad():
        # ranked[i, j] is the class index with the (j + 1)-th highest score.
        ranked = outputs.topk(k=deepest, dim=1, largest=True, sorted=True)[1]
        target = labels.view(-1, 1).expand_as(ranked)

        # How many samples have their label at exactly rank j ...
        hitsatrank = ranked.eq(target).sum(dim=0)
        # ... accumulated into "label found within the first j + 1 ranks".
        hitswithin = hitsatrank.cumsum(dim=0)

        tops = {}
        for k in top:
            tops[k] = hitswithin[k - 1].item()

    return tops

In [20]:
def runepoch(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             finetune: bool,
             device: torch.device,
             lossfn: Optional[torch.nn.Module] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             top: Sequence[int] = (1,)) -> Dict[str, float]:
    """Run one pass over ``dataloader``, training when an optimizer is given.

    Args:
        model: Network to evaluate or train.
        dataloader: Yields ``(images, labels, imagepath)`` batches.
        finetune: When True the model is kept in eval mode even while training.
        device: Device the images and labels are moved to.
        lossfn: Loss module; its output is averaged with ``.mean()``. Required
            whenever ``optimizer`` is given.
        optimizer: If given, a backward pass and an optimizer step are run for
            every batch; if None, gradients are disabled.
        top: Values of k for which top-k accuracy is tracked.

    Returns:
        Dict with ``'loss'`` and one ``'accuracytop{k}'`` entry per k, each
        taken from the ``avg`` attribute of the corresponding meter.

    Raises:
        ValueError: If ``optimizer`` is given without ``lossfn``.
    """
    training = optimizer is not None
    # Fail before touching the data instead of hitting an unbound `loss` at
    # the first backward pass.
    if training and lossfn is None:
        raise ValueError('runepoch needs a lossfn whenever an optimizer is given')

    # Training mode only when optimising and not finetuning; eval mode otherwise.
    model.train(training and not finetune)

    # Running trackers for the loss and for each top-k accuracy.
    # NOTE(review): Averagemeter comes from the project's utils module; it is
    # used here through update(value, n), .val and .avg.
    losses = Averagemeter()
    accuraciestopk = {k: Averagemeter() for k in top}

    tqdmloader = tqdm(dataloader)
    with torch.set_grad_enabled(training):
        for images, labels, _ in tqdmloader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            batchsize = labels.size(0)

            description = []
            outputs = model(images)

            if lossfn is not None:
                loss = lossfn(outputs, labels).mean()
                losses.update(loss.item(), batchsize)
                description.append(f'Loss {losses.val:.4f} ({losses.avg:.4f})')

            # Backward pass and parameter update.
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # accuracy() returns correct counts; convert to a batch percentage.
            topcorrect = accuracy(outputs, labels, top=top)
            for k, acc in accuraciestopk.items():
                acc.update(topcorrect[k] * (100. / batchsize), n=batchsize)
                description.append(f'accuracy@{k} {acc.val:.3f} ({acc.avg:.3f})')

            tqdmloader.set_description(' '.join(description))

    metrics = {'loss': losses.avg}
    for k, acc in accuraciestopk.items():
        metrics[f'accuracytop{k}'] = acc.avg

    return metrics

In [21]:
# Metrics of the best epoch so far; updated whenever validation top-1 accuracy improves.
bestmetrics: Dict[str, float] = {}
# Values of k for which top-k accuracy is reported.
top = (1, 3)
# Consecutive epochs without a new best validation top-1 accuracy.
earlystopping = 0

In [22]:
# Main train/validate loop: checkpoint whenever validation top-1 accuracy
# improves, and stop once it has failed to improve for 10 consecutive epochs.

# Only the final layer is tuned while epoch < finetuneepochs; afterwards the
# whole model is trained.
finetuneepochs = config['finetuneepochs']

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
ckptdir = os.path.join(logdir, 'ckpt')
os.makedirs(ckptdir, exist_ok=True)

for epoch in range(startepoch + 1, int(config['numepochs'])):
    print(f'Epoch: {epoch + 1}')
    writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch)

    finetune = finetuneepochs > epoch
    setfinetune(model, finetune)

    print('Training')
    trainmetrics = runepoch(model, traindataloader, finetune, device, lossfn, optimizer, top)
    trainmetrics = prefix_all_keys(trainmetrics, prefix='train/')

    print('Validation')
    # No optimizer is passed, so this pass runs without gradients.
    valmetrics = runepoch(model, valdataloader, finetune, device, lossfn, top=top)
    valmetrics = prefix_all_keys(valmetrics, prefix='val/')

    scheduler.step()

    if valmetrics['val/accuracytop1'] > bestmetrics.get('val/accuracytop1', 0):
        filename = os.path.join(ckptdir, f'ckpt_{epoch}.pt')
        print(f'New best model! Saving checkpoint to {filename} with accuracy {valmetrics["val/accuracytop1"]}')
        state = {
            'epoch': epoch,
            # Unwrap a `.module` attribute when present (as set by wrappers
            # such as DataParallel) so the saved keys carry no wrapper prefix.
            'model': getattr(model, 'module', model).state_dict(),
            'val/acc': valmetrics['val/accuracytop1'],
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, filename)
        bestmetrics.update(trainmetrics)
        bestmetrics.update(valmetrics)
        bestmetrics['epoch'] = epoch
        earlystopping = 0
    else:
        earlystopping += 1

    if earlystopping > 9:
        print('Model did not improve for 10 epochs so exiting!')
        # Actually leave the loop; without this the message is printed on
        # every later epoch while training carries on.
        break

  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 1
Training
True


Loss 0.0000 (416.3581) accuracy@1 100.000 (59.062) accuracy@3 100.000 (88.588): 100%|██████████| 9446/9446 [04:00<00:00, 39.24it/s] 
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
True


Loss 0.0000 (279.2965) accuracy@1 100.000 (74.093) accuracy@3 100.000 (94.843): 100%|██████████| 4053/4053 [02:15<00:00, 29.95it/s] 


New best model! Saving checkpoint to logs/ckpt/ckpt_0.pt with accuracy 74.09326424870466


  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 2
Training
False


Loss 6766.0547 (22174.7081) accuracy@1 0.000 (20.083) accuracy@3 100.000 (55.452): 100%|██████████| 9446/9446 [25:37<00:00,  6.15it/s] 
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
False


Loss 707162.3750 (190160.1122) accuracy@1 0.000 (23.908) accuracy@3 0.000 (52.751): 100%|██████████| 4053/4053 [02:56<00:00, 22.95it/s]  
  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 3
Training
False


Loss 28984.3008 (20805.9848) accuracy@1 0.000 (20.474) accuracy@3 100.000 (55.082): 100%|██████████| 9446/9446 [26:34<00:00,  5.92it/s] 
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
False


Loss 43920.5508 (24964.1793) accuracy@1 0.000 (23.884) accuracy@3 0.000 (59.018): 100%|██████████| 4053/4053 [02:41<00:00, 25.17it/s]  
  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 4
Training
False


Loss 0.0000 (20100.4955) accuracy@1 100.000 (19.839) accuracy@3 100.000 (56.172): 100%|██████████| 9446/9446 [28:02<00:00,  5.61it/s]  
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
False


Loss 7221.7266 (3578.7466) accuracy@1 0.000 (14.631) accuracy@3 0.000 (57.044): 100%|██████████| 4053/4053 [02:50<00:00, 23.81it/s] 
  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 5
Training
False


Loss 8188.3594 (21486.5169) accuracy@1 0.000 (20.834) accuracy@3 100.000 (56.680): 100%|██████████| 9446/9446 [28:22<00:00,  5.55it/s] 
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
False


Loss 39463.6836 (191999.9950) accuracy@1 0.000 (26.326) accuracy@3 0.000 (66.321): 100%|██████████| 4053/4053 [02:40<00:00, 25.29it/s]   
  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 6
Training
False


Loss 511.9922 (20459.5487) accuracy@1 0.000 (19.839) accuracy@3 100.000 (55.410): 100%|██████████| 9446/9446 [27:00<00:00,  5.83it/s]  
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
False


Loss 513252.0312 (83158.0075) accuracy@1 0.000 (8.784) accuracy@3 0.000 (42.635): 100%|██████████| 4053/4053 [02:42<00:00, 24.87it/s] 
  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 7
Training
False


Loss 22158.1758 (18980.4102) accuracy@1 0.000 (20.622) accuracy@3 0.000 (56.214): 100%|██████████| 9446/9446 [28:05<00:00,  5.60it/s]  
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
False


Loss 5389666.0000 (3658805.0763) accuracy@1 0.000 (23.908) accuracy@3 100.000 (55.243): 100%|██████████| 4053/4053 [02:47<00:00, 24.20it/s]
  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 8
Training
False


Loss 0.0000 (18529.0553) accuracy@1 100.000 (20.517) accuracy@3 100.000 (55.166): 100%|██████████| 9446/9446 [27:32<00:00,  5.72it/s]  
  0%|          | 0/4053 [00:00<?, ?it/s]

Validation
False


Loss 36404.3828 (21496.9719) accuracy@1 0.000 (16.062) accuracy@3 0.000 (57.044): 100%|██████████| 4053/4053 [02:32<00:00, 26.54it/s]  
  0%|          | 0/9446 [00:00<?, ?it/s]

Epoch: 9
Training
False


Loss 25167.1484 (17948.0890) accuracy@1 0.000 (20.006) accuracy@3 0.000 (55.889):  68%|██████▊   | 6393/9446 [17:54<08:33,  5.95it/s]  


KeyboardInterrupt: 