# Requirements


In [None]:
! pip install pytorch_lightning 
! pip install wandb
! pip install torchvision
! pip install torchinfo
! pip install self-attention-cv

# Imports and Model Building

In [2]:
import torch 
import pytorch_lightning as pl
import torch.cuda
import wandb
import torchvision
import math
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.utils import data
from torchinfo import summary
from self_attention_cv import AxialAttentionBlock

## Model Classes

In [83]:
class Conv2dBlock(nn.Module):

    def __init__(
            self, in_channels, out_channels, kernel_size,
            dilation=1, dropout=0.0, pool_size=1,
            activations = nn.Mish
    ):
        super().__init__()
        padding = dilation * (kernel_size - 1) // 2  # padding needed to maintain size

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, 
                      padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_channels, momentum=0.0735)
        ]
        if pool_size > 1:
            layers.append(nn.MaxPool2d(kernel_size=pool_size))
        if dropout > 0.0:
            layers.append(nn.Dropout(p=dropout))
        layers.append(activations())

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class TransposeConv2dBlock(nn.Module):

    def __init__(
            self, in_channels, out_channels, kernel_size, dropout=0.0,
            activations = nn.Mish
    ):
        super().__init__()

        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size),
            nn.BatchNorm2d(out_channels, momentum=0.0735)
        ]
        if dropout > 0.0:
            layers.append(nn.Dropout(p=dropout))
        layers.append(activations())

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class DilatedResConv2dBlock(nn.Module):

    def __init__(
            self, in_channels, mid_channels, out_channels, kernel_size,
            dilation=1, dropout=0.0, activations = nn.Mish
    ):
        super().__init__()

        self.blocks = nn.Sequential(
            Conv2dBlock(in_channels, mid_channels, kernel_size, dilation=dilation),
            Conv2dBlock(mid_channels, out_channels, kernel_size, dropout=dropout)
        )

        self.activation = activations()

    def forward(self, x):
        blocks_output = self.activation(self.blocks(x))
        print(blocks_output.shape, x.shape)
        x = x + blocks_output   # residual connection
        return x

In [132]:
class ClassifierHead(nn.Module):

    def __init__(self, n_layer, n_head, feature_map_dim, input_channels=96, mid_channels = 24, output_channels=12, 
                 kernel_size=6, max_pool_kernel_size = 2, conv_layers=6, hidden_size=1024):
        super().__init__()

        axial_attn_block = AxialAttentionBlock(
            input_channels, feature_map_dim, n_head)

        # build up the modules
        modules = [axial_attn_block for _ in range(n_layer)]
        for i in range(conv_layers):
            modules += [Conv2dBlock(input_channels,
                                    mid_channels, kernel_size=kernel_size)]
            modules += [nn.BatchNorm2d(mid_channels, momentum=0.0735)]
            modules += [Conv2dBlock(mid_channels,
                                    input_channels, kernel_size=kernel_size)]
            modules += [nn.BatchNorm2d(input_channels, momentum=0.0735)]
            # add maxpooling after every other conv layer
            if i % 2 == 0: 
                modules += [nn.MaxPool2d(kernel_size=max_pool_kernel_size)]
        modules += [Conv2dBlock(input_channels, output_channels, kernel_size=1)]
        # unpack modules
        self.trunk = nn.Sequential(*modules)

    def forward(self, x):
        return self.trunk(x)


class ClassifierTrunk(nn.Module):

    def __init__(self, input_channels=1024, mid_channels=64, out_channels=96, kernel_size=3, transpose_kernel=12, dropout=0.04):
        super().__init__()

        # build up the modules
        modules = [Conv2dBlock(
            input_channels, out_channels, kernel_size=kernel_size)]
        dilation = 1.0
        for _ in range(6):
            layer_dilation = round(dilation)
            modules += [TransposeConv2dBlock(out_channels,
                                             out_channels, transpose_kernel, dropout)]
            modules += [DilatedResConv2dBlock(
                out_channels, mid_channels, out_channels, kernel_size, layer_dilation, dropout)]
            dilation *= 1.75

        # unpack modules
        self.head = nn.Sequential(*modules)

    def forward(self, input_embeds):
        return self.head(input_embeds)


class XRayPredictor(nn.Module):

    def __init__(self, n_layer, n_head, n_inner, dropout):
        super().__init__()
        self.trunk = ClassifierTrunk()
        self.head = ClassifierHead(n_layer, n_head, n_inner, dropout)
        # TODO: figure out the output dimensions of the trunk
        # make same dims as head output for residual.
        self.trunk_residual = nn.Linear()
        self.residual_act = nn.Mish()
        self.fc_out = nn.Linear(48, 15)

    def forward(self, input_embeds):
        z = self.trunk(input_embeds)
        y = self.head(z)
        y = self.residual_act(y + self.trunk_residual(z))
        return self.fc_out(y)


class LitContactPredictor(pl.LightningModule):

    def __init__(
            self,
            n_layer, n_head, n_inner, dropout,
            augment_rc, augment_shift, lr,
            **kwargs
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["model"])

        self.model = XRayPredictor(n_layer, n_head, n_inner, dropout)
        self.lr = lr

    def forward(self, input_seqs):
        return self.model(input_seqs, flatten=True)

    def training_step(self, batch, batch_idx):
        batch = self._stochastic_augment(batch)
        loss, batch_size = self._process_batch(batch)
        self.log('train_loss', loss, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, batch_size = self._process_batch(batch)
        self.log('val_loss', loss, batch_size=batch_size)
        return loss

    def test_step(self, batch, batch_idx):
        loss, batch_size = self._process_batch(batch, test=True)
        self.log('test_loss', loss, batch_size=batch_size)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(), lr=self.lr, momentum=0.975)
        return optimizer


## DataModule stuff 

In [5]:
import pathlib
import numpy as np
import pytorch_lightning as pl
from torchvision.datasets import DatasetFolder

class LungDetectionDataModule(pl.LightningDataModule):

    def __init__(self, batch_size=2, num_workers=0, master_path=""):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        data_dir = master_path
        self.train = DatasetFolder(master_path + 'embeddingtrain2', 
                                   loader=torch.load, extensions=('.tensor'))
        self.valid = DatasetFolder(master_path + 'embeddingval2', 
                                   loader=torch.load, extensions=('.tensor'))
        self.test = DatasetFolder(master_path + 'embeddingval2', 
                                  loader=torch.load, extensions=('.tensor'))

    def train_dataloader(self):
        return data.DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )

    def val_dataloader(self):
        return data.DataLoader(
            self.valid,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

    def test_dataloader(self):
        return data.DataLoader(
            self.test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )


In [6]:
def train_main(batch_size=128, num_workers = 4, max_epochs = 20, 
               master_path="", n_layer = 5, n_head = 5, n_inner = 64,  dropout = 0.01):
    # seed experiment
    pl.seed_everything(seed=123)

    # construct datamodule
    datamodule = LungDetectionDataModule(batch_size=batch_size, 
                                         num_workers=num_workers, 
                                         master_path = "")
    data_size = len(datamodule.train)

    # construct model
    lit_model = LitContactPredictor(seed=123, batch_size=batch_size, 
                                    num_workers=num_workers, 
                                    data_size=data_size,
                                    n_layer= n_layer,
                                    n_head = n_head,
                                    n_inner = n_inner,
                                    dropout=dropout
                                    )

    # logging
    save_dir = pathlib.Path(__file__).parents[2]
    logger = WandbLogger(project="train_xray", log_model="all", save_dir=str(save_dir))
    logger.experiment.config["train_set_len"] = len(datamodule.train)
    logger.experiment.config["val_set_len"] = len(datamodule.valid)
    logger.experiment.config["batch_size"] = batch_size

    # callbacks
    early_stopping = pl.callbacks.EarlyStopping(monitor="val_accuracy", patience=40)
    checkpointing = pl.callbacks.ModelCheckpoint(monitor="val_accuracy", mode="min", save_top_k=20)
    stochastic_weighting = pl.callbacks.StochasticWeightAveraging(swa_epoch_start=0.75, 
                                                                  annealing_epochs=5, 
                                                                  swa_lrs=4.5e-4)
    lr_monitor = pl.callbacks.LearningRateMonitor("step", True)

    # training
    trainer = pl.Trainer(
        callbacks=[early_stopping, checkpointing, stochastic_weighting, lr_monitor],
        deterministic=True,
        gpus=-1,
        gradient_clip_val=15,
        logger=logger,
        log_every_n_steps=1,
        enable_progress_bar=True,
        max_epochs=max_epochs,
    )

    trainer.fit(lit_model, datamodule=datamodule)

    wandb.finish()

    return lit_model

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount = True)

Mounted at /content/drive


# Training the model

In [7]:
#load data
#location on Google Drive
master_path = '../data_processing/embeddings/'

In [8]:
train_configs = {"n_layer": 2, 
                 "n_head": 4,
                 "n_inner": 64,
                 "dropout": 0.01}

## Sanity checks

In [19]:
test_datamodule = LungDetectionDataModule(master_path=master_path, batch_size = 128)

In [20]:
test_train_loader = test_datamodule.train_dataloader()

In [22]:
test_train_sample = next(iter(test_train_loader))
test_input, test_label = test_train_sample
test_input, test_label = test_input.cuda(), test_label.cuda()
test_input.shape, test_label.shape

(torch.Size([128, 1024, 7, 7]), torch.Size([128]))

In [23]:
# test_label is a bit werid
test_label

tensor([10,  5,  8, 11, 10, 14,  0, 13, 10, 10,  0, 14,  9, 10, 14, 12, 13, 14,
        13,  9,  3, 11,  8,  9,  2,  0,  4,  0,  8,  4,  8, 10, 14,  9, 10,  8,
         8, 10,  4, 10,  2, 11, 13,  2,  4,  8, 10,  8,  4,  8,  4, 11,  8,  0,
        13,  8, 12, 10, 13, 12, 10,  8,  0,  8,  4, 10,  5, 10,  8, 10, 14, 10,
        10,  8, 10, 13, 10,  6, 10, 13, 14, 14,  1, 10,  1, 14, 10, 14, 14,  1,
        13,  8,  4, 11,  0,  4,  8,  0, 10, 10, 14,  8,  8,  1, 11,  0,  1,  8,
         0,  9,  8, 11,  8, 10,  4,  4, 10, 10,  8, 14, 11,  8, 14,  8, 10, 13,
         8, 12], device='cuda:0')

In [85]:
n_layer, n_head, dropout = train_configs["n_layer"], train_configs["n_head"], train_configs["dropout"]

In [101]:
input_channels, mid_channels, out_channels, kernel_size, transpose_kernel, dropout_trunk = 1024, 64, 48, 5, 5, 0.04
dryrun_trunk = ClassifierTrunk(input_channels, mid_channels, out_channels, kernel_size, transpose_kernel, dropout_trunk)

In [102]:
summary(dryrun_trunk, input_data = test_input, device = "cuda")

torch.Size([128, 48, 11, 11]) torch.Size([128, 48, 11, 11])
torch.Size([128, 48, 15, 15]) torch.Size([128, 48, 15, 15])
torch.Size([128, 48, 19, 19]) torch.Size([128, 48, 19, 19])
torch.Size([128, 48, 23, 23]) torch.Size([128, 48, 23, 23])
torch.Size([128, 48, 27, 27]) torch.Size([128, 48, 27, 27])
torch.Size([128, 48, 31, 31]) torch.Size([128, 48, 31, 31])


Layer (type:depth-idx)                             Output Shape              Param #
ClassifierTrunk                                    [128, 48, 31, 31]         --
├─Sequential: 1-1                                  [128, 48, 31, 31]         --
│    └─Conv2dBlock: 2-1                            [128, 48, 7, 7]           --
│    │    └─Sequential: 3-1                        [128, 48, 7, 7]           1,228,944
│    └─TransposeConv2dBlock: 2-2                   [128, 48, 11, 11]         --
│    │    └─Sequential: 3-2                        [128, 48, 11, 11]         57,744
│    └─DilatedResConv2dBlock: 2-3                  [128, 48, 11, 11]         --
│    │    └─Sequential: 3-3                        [128, 48, 11, 11]         153,936
│    │    └─Mish: 3-4                              [128, 48, 11, 11]         --
│    └─TransposeConv2dBlock: 2-4                   [128, 48, 15, 15]         --
│    │    └─Sequential: 3-5                        [128, 48, 15, 15]         57,744
│    └─DilatedR

In [103]:
dryrun_trunk_output = dryrun_trunk(test_input)
dryrun_trunk_output.shape

torch.Size([128, 48, 11, 11]) torch.Size([128, 48, 11, 11])
torch.Size([128, 48, 15, 15]) torch.Size([128, 48, 15, 15])
torch.Size([128, 48, 19, 19]) torch.Size([128, 48, 19, 19])
torch.Size([128, 48, 23, 23]) torch.Size([128, 48, 23, 23])
torch.Size([128, 48, 27, 27]) torch.Size([128, 48, 27, 27])
torch.Size([128, 48, 31, 31]) torch.Size([128, 48, 31, 31])


torch.Size([128, 48, 31, 31])

In [147]:
dryrun_head = ClassifierHead(3, 8, dryrun_trunk_output.size(2), 
                             input_channels = out_channels, 
                             mid_channels= 24,  
                             output_channels = 12, 
                             kernel_size = 4, 
                             max_pool_kernel_size = 2, 
                             conv_layers = 2, 
                             hidden_size = 512)

In [148]:
summary(dryrun_head, input_data = dryrun_trunk_output, verbose=1)

Layer (type:depth-idx)                             Output Shape              Param #
ClassifierHead                                     [128, 12, 12, 12]         --
├─Sequential: 1-1                                  [128, 12, 12, 12]         --
│    └─AxialAttentionBlock: 2-1                    [128, 48, 31, 31]         --
│    │    └─Sequential: 3-1                        [128, 128, 31, 31]        6,400
│    │    └─ReLU: 3-2                              [128, 128, 31, 31]        --
│    │    └─AxialAttention: 3-3                    [3968, 128, 31]           35,792
│    │    └─AxialAttention: 3-4                    [3968, 128, 31]           35,792
│    │    └─ReLU: 3-5                              [3968, 128, 31]           --
│    │    └─Sequential: 3-6                        [128, 48, 31, 31]         6,240
│    │    └─ReLU: 3-7                              [128, 48, 31, 31]         --
│    └─AxialAttentionBlock: 2-2                    [128, 48, 31, 31]         (recursive)
│    │    └─

Layer (type:depth-idx)                             Output Shape              Param #
ClassifierHead                                     [128, 12, 12, 12]         --
├─Sequential: 1-1                                  [128, 12, 12, 12]         --
│    └─AxialAttentionBlock: 2-1                    [128, 48, 31, 31]         --
│    │    └─Sequential: 3-1                        [128, 128, 31, 31]        6,400
│    │    └─ReLU: 3-2                              [128, 128, 31, 31]        --
│    │    └─AxialAttention: 3-3                    [3968, 128, 31]           35,792
│    │    └─AxialAttention: 3-4                    [3968, 128, 31]           35,792
│    │    └─ReLU: 3-5                              [3968, 128, 31]           --
│    │    └─Sequential: 3-6                        [128, 48, 31, 31]         6,240
│    │    └─ReLU: 3-7                              [128, 48, 31, 31]         --
│    └─AxialAttentionBlock: 2-2                    [128, 48, 31, 31]         (recursive)
│    │    └─

In [None]:
# model dryrun 
dryrun_model = LitContactPredictor(seed=123, batch_size=2, 
                                    num_workers=4, data_size=2, 
                                   n_layer = 5, n_head = 5, 
                                   n_inner = 64,  dropout = 0.01)


## Model training

In [None]:
model = train_main(master_path = master_path)