# Trains and evaluates the 2020 extended classifier

In [None]:
from deepcell.datasets.visual_behavior_extended_dataset import VisualBehaviorExtendedDataset
from deepcell.trainer import Trainer
from deepcell.inference import cv_performance
from deepcell.models.classifier import Classifier
from deepcell.data_splitter import DataSplitter
from deepcell.transform import Transform

import numpy as np
import torch
import torchvision
from torchvision import transforms
from imgaug import augmenters as iaa
import random
import matplotlib.pyplot as plt
from pathlib import Path

# Seed the global random number generators of Python, NumPy and PyTorch with
# the same fixed value so repeated runs draw the same numbers from them.
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(1234)
# Kept as the final bare expression so the cell still echoes its return value.
torch.manual_seed(1234)

In [None]:
# IPython shell escape: print the abbreviated hash of the commit HEAD points to.
!git rev-parse --short HEAD

In [None]:
# Passed as ``artifact_destination`` when the dataset is constructed.
ARTIFACT_DESTINATION = Path('/tmp/artifacts')

# Checkpoint directory: <two levels above the working directory>/saved_models/<run name>.
CHECKPOINT_PATH = Path.cwd().parent.parent.joinpath(
    'saved_models',
    '022122_lr1e-4_wd_0_linear_classifier_upto_layer_22_freeze_upto_8_dropout_0',
)

In [None]:
# Projects that are left out of the dataset for this experiment.
excluded_projects = [
    'ophys-experts-go-big-or-go-home',
    'ophys-experts-slc-oct-2020_ophys-experts-go-big-or-go-home',
    'ophys-expert-danielsf-additions',
]

dataset = VisualBehaviorExtendedDataset(
    exclude_projects=excluded_projects,
    artifact_destination=ARTIFACT_DESTINATION,
)

In [None]:
# Number of entries in the loaded dataset.
len(dataset.dataset)

In [None]:
# Smallest and largest 'date' value recorded for each project.
dataset.project_meta.groupby('project_name')['date'].agg(['min', 'max'])

In [None]:
# Number of project_meta rows belonging to each project.
dataset.project_meta.groupby('project_name').size()

In [None]:
def _compose_pipeline(augmenters):
    """Chain imgaug augmenters with tensor conversion and channel normalization.

    ``augmenters`` is a list of imgaug augmenters; they are applied in order to
    an image before it is converted to a tensor and normalized per channel.
    """
    return transforms.Compose([
        iaa.Sequential(augmenters).augment_image,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Training pipeline: rotation drawn from a list of right angles, random
# horizontal/vertical flips, then a 128x128 center crop.
all_transform = _compose_pipeline([
    iaa.Affine(rotate=[0, 90, 180, 270, -90, -180, -270], order=0),
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
    iaa.CenterCropToFixedSize(height=128, width=128),
])
train_transform = Transform(all_transform=all_transform)

# Evaluation pipeline: center crop only, no random augmentation.
# ``all_transform`` is rebound here; the object already handed to
# ``train_transform`` is unaffected.
all_transform = _compose_pipeline([
    iaa.CenterCropToFixedSize(height=128, width=128),
])
test_transform = Transform(all_transform=all_transform)

data_splitter = DataSplitter(
    model_inputs=dataset.dataset,
    train_transform=train_transform,
    test_transform=test_transform,
    seed=1234,
    image_dim=(128, 128),
    use_correlation_projection=True,
)
# test_size=.3 is presumably the held-out fraction -- confirm against DataSplitter.
train, test = data_splitter.get_train_test_split(test_size=.3)

# Sanity checks, one per line: combined size of both splits, mean of the
# training labels, size of the test split.
print(len(train) + len(test), train.y.mean(), len(test), sep='\n')

In [None]:
# Instantiate torchvision's VGG-11 with batch normalization, using pretrained
# weights and no download progress bar, then echo its layer structure.
model = torchvision.models.vgg11_bn(progress=False, pretrained=True)
model

In [None]:
# Wrap a freshly loaded pretrained VGG-11-BN in the project's Classifier and
# echo the resulting network.
backbone = torchvision.models.vgg11_bn(pretrained=True, progress=False)
model = Classifier(
    model=backbone,
    truncate_to_layer=22,
    freeze_up_to_layer=8,
    classifier_cfg=[],
    dropout_prob=0.0,
    final_activation_map_spatial_dimensions=(1, 1),
)
model

In [None]:
def optimizer():
    """Build an Adam optimizer over the parameters of the global ``model``."""
    return torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0)


def scheduler(optimizer):
    """Build a plateau scheduler for ``optimizer``.

    The learning rate is multiplied by 0.5 once the monitored quantity has
    stopped decreasing for longer than the patience of 15 epochs.
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=15,
        verbose=True,
        factor=.5,
    )


criterion = torch.nn.BCEWithLogitsLoss()

# The trainer receives the optimizer/scheduler factories rather than built
# instances.
trainer = Trainer(
    model=model,
    n_epochs=1000,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    save_path=str(CHECKPOINT_PATH),
    early_stopping=30,
)
cv_metrics = trainer.cross_validate(
    train_dataset=train,
    data_splitter=data_splitter,
    batch_size=64,
)

In [None]:
# Per-fold F1 learning curves, read back from the checkpoints written during
# cross-validation (one subplot per fold).
fig, ax = plt.subplots(nrows=5, ncols=1, figsize=(5, 20))

for fold, axis in enumerate(ax):
    # Only the logged metrics are needed, so map any tensors stored in the
    # checkpoint onto the CPU. Without map_location, a checkpoint saved from a
    # GPU run cannot be loaded on a CPU-only machine and would otherwise be
    # restored onto the GPU just to draw a plot.
    checkpoint = torch.load(f'{CHECKPOINT_PATH}/{fold}_model.pt', map_location='cpu')
    performance = checkpoint['performance']
    axis.plot(performance['train']['f1s'], label='train')
    axis.plot(performance['val']['f1s'], label='val')
    axis.legend()
    axis.set_xlabel('Epoch')
    axis.set_ylabel('F1')

plt.show()

In [None]:
# Per-fold loss curves, read back from the checkpoints written during
# cross-validation (one subplot per fold).
fig, ax = plt.subplots(nrows=5, ncols=1, figsize=(5, 20))

for fold, axis in enumerate(ax):
    # Only the logged metrics are needed, so map any tensors stored in the
    # checkpoint onto the CPU. Without map_location, a checkpoint saved from a
    # GPU run cannot be loaded on a CPU-only machine and would otherwise be
    # restored onto the GPU just to draw a plot.
    checkpoint = torch.load(f'{CHECKPOINT_PATH}/{fold}_model.pt', map_location='cpu')
    performance = checkpoint['performance']
    axis.plot(performance['train']['losses'], label='train')
    axis.plot(performance['val']['losses'], label='val')
    axis.legend()
    axis.set_xlabel('Epoch')
    axis.set_ylabel('loss')

plt.show()

In [None]:
# Per-fold AUPR curves taken from the in-memory cross-validation metrics.
fig, ax = plt.subplots(nrows=5, ncols=1, figsize=(5, 10))

for fold, axis in enumerate(ax):
    fold_train = cv_metrics.train_metrics[fold]
    fold_valid = cv_metrics.valid_metrics[fold]
    # Both curves are clipped at 60 epochs past the best epoch recorded in the
    # training metrics of this fold.
    last_epoch = fold_train.best_epoch + 60
    axis.plot(fold_train.auprs[:last_epoch], label='train')
    axis.plot(fold_valid.auprs[:last_epoch], label='val')
    axis.legend()
    axis.set_xlabel('Epoch')
    axis.set_ylabel('AUPR')

plt.show()

In [None]:
# Evaluate the cross-validation checkpoints stored under CHECKPOINT_PATH.
# The network built here must have the same architecture as the one that was
# trained and saved there (truncated at layer 22, frozen up to layer 8, no
# hidden classifier layers, no dropout, 1x1 final activation map -- also
# spelled out in the checkpoint directory name). A different configuration,
# such as truncate_to_layer=15 with classifier_cfg=[1024, 1024], has layers
# that do not correspond to the saved weights.
model = torchvision.models.vgg11_bn(pretrained=True, progress=False)
model = Classifier(model=model, truncate_to_layer=22, freeze_up_to_layer=8,
                   classifier_cfg=[], dropout_prob=0.0,
                   final_activation_map_spatial_dimensions=(1, 1))
preds, metrics = cv_performance(model=model, data_splitter=data_splitter,
                                train=train, checkpoint_path=f'{CHECKPOINT_PATH}')
metrics