In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

In [None]:
# Input geometry: each sample's input tensor is (NUM_CARDS, NUM_DETECTORS, NUM_TIMESTEPS);
# NUM_CARDS is also the model's input-channel count.
NUM_CARDS = 3
NUM_DETECTORS = 60
NUM_TIMESTEPS = 512

# Seed the global torch RNG so the torch.rand-based synthetic data and the model's
# weight initialisation are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED);

# Dataset

In [None]:
class GeoDataset(Dataset):
    """Synthetic placeholder dataset of uniformly random tensors.

    Each item is a pair ``(inputs, targets)``:
      * ``inputs``  has shape ``(num_cards, num_detectors, num_timesteps)``;
      * ``targets`` has shape ``(num_detectors, 1)``.
    Both are drawn fresh from ``torch.rand`` on every access, so the same index
    does not return the same values twice.

    Args:
        num_cards: size of the first axis of ``inputs``.
        num_detectors: size of the detector axis of ``inputs`` and ``targets``.
        num_timesteps: size of the last axis of ``inputs``.
        num_samples: length reported by ``len()`` (default 1000).

    Raises:
        ValueError: if ``num_samples`` is negative.
    """

    def __init__(self, num_cards, num_detectors, num_timesteps, num_samples=1000):
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        self.num_cards = num_cards
        self.num_detectors = num_detectors
        self.num_timesteps = num_timesteps
        self.num_samples = num_samples

    def __getitem__(self, idx):
        # Bounds-check: plain iteration (`for x in dataset`, `list(dataset)`) relies on
        # IndexError to stop; without it, iteration would never terminate.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of length {self.num_samples}"
            )
        inputs = torch.rand(self.num_cards, self.num_detectors, self.num_timesteps)
        targets = torch.rand(self.num_detectors, 1)
        return inputs, targets

    def __len__(self):
        return self.num_samples

In [None]:
dataset = GeoDataset(
    num_cards=NUM_CARDS,
    num_detectors=NUM_DETECTORS,
    num_timesteps=NUM_TIMESTEPS,
)
# DataLoader defaults apply (no shuffling, single-process loading); mini-batches of 10.
dataloader = DataLoader(dataset, batch_size=10)

In [None]:
# Peek at one batch; targets are stacked to (batch_size, NUM_DETECTORS, 1),
# i.e. torch.Size([10, 60, 1]) with the settings above.
sample_inputs, sample_targets = next(iter(dataloader))
sample_targets.shape

# Model

In [None]:
class ResNeXtBlock(nn.Module):
    """Residual block: depthwise conv -> GroupNorm -> pointwise MLP, added to the input.

    NOTE(review): structurally this (large-kernel depthwise conv, norm, 1x1
    expand / SiLU / 1x1 project, residual add) resembles a ConvNeXt block more
    than a grouped-bottleneck ResNeXt block; the name is kept for compatibility.

    Args:
        num_channels: channels in and out; must be divisible by ``norm_groups``.
        norm_groups: number of groups for ``nn.GroupNorm``.
        expansion_rate: width multiplier of the hidden 1x1 layer.
        kernel_size: odd side length of the square depthwise kernel (default 7).
            Padding of ``kernel_size // 2`` keeps both spatial sizes unchanged.

    Raises:
        ValueError: if ``kernel_size`` is even (the output would not match the
            input shape, breaking the residual add).

    Shape:
        ``(batch, num_channels, H, W)`` -> same shape.
    """

    def __init__(self, num_channels, norm_groups=32, expansion_rate=4, kernel_size=7):
        super().__init__()
        if kernel_size % 2 != 1:
            raise ValueError(
                f"kernel_size must be odd to preserve spatial size, got {kernel_size}"
            )
        # groups=num_channels makes this depthwise: one spatial filter per channel.
        self.dw_conv = nn.Conv2d(
            num_channels,
            num_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=num_channels,
        )
        self.group_norm = nn.GroupNorm(norm_groups, num_channels)
        hidden_channels = expansion_rate * num_channels
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, hidden_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, num_channels, kernel_size=1),
        )

    def forward(self, x):
        out = self.feed_forward(self.group_norm(self.dw_conv(x)))
        return x + out

In [None]:
class ResNet1D(nn.Module):
    """Convolutional regressor producing one value per row (H index) of its input.

    Pipeline:
      1. stem: a (1, 4) conv with stride (1, 4) maps ``num_channels`` to
         ``model_channels`` and shrinks only the last (W) axis by 4, then GroupNorm;
      2. one stage per entry of ``dim_mult`` / ``num_blocks``: ``num_blocks[i]``
         ResNeXtBlocks, GroupNorm, then a (1, 2) conv with stride (1, 2) that changes
         the width to ``model_channels * dim_mult[i]`` and halves the W axis;
      3. mean-pool over W, move channels last, and project to 1 with a Linear.

    Args:
        model_channels: width after the stem.
        num_channels: number of input channels.
        groups: GroupNorm groups; every stage width must be divisible by it.
        expansion_rate: hidden-width multiplier passed to each ResNeXtBlock.
        dim_mult: per-stage width multipliers (any sequence, e.g. tuple or list).
        num_blocks: per-stage ResNeXtBlock counts; same length as ``dim_mult``.

    Raises:
        ValueError: if ``dim_mult`` is empty or its length differs from ``num_blocks``.

    Shape:
        ``(batch, num_channels, H, W)`` -> ``(batch, H, 1)``. W is floor-divided by 4
        in the stem and by 2 in every stage, so it must be at least
        ``4 * 2 ** len(dim_mult)``.
    """

    def __init__(
        self,
        model_channels=256,
        num_channels=3,
        groups=32,
        expansion_rate=4,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
    ):
        super().__init__()
        # Accept any sequence (the old `(1,) + dim_mult` required a tuple).
        dim_mult = tuple(dim_mult)
        num_blocks = tuple(num_blocks)
        if not dim_mult:
            raise ValueError("dim_mult must contain at least one stage")
        # zip() would silently drop stages on a mismatch, leaving out_layer sized
        # for a width the network never produces.
        if len(dim_mult) != len(num_blocks):
            raise ValueError(
                f"dim_mult and num_blocks must have the same length, "
                f"got {len(dim_mult)} and {len(num_blocks)}"
            )

        self.stem = nn.Sequential(
            nn.Conv2d(num_channels, model_channels, kernel_size=(1, 4), stride=(1, 4)),
            nn.GroupNorm(groups, model_channels),
        )

        # Stage i maps hidden_dims[i] -> hidden_dims[i + 1] channels.
        hidden_dims = [model_channels * mult for mult in (1, *dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        self.resnext_blocks = nn.Sequential(*[
            nn.Sequential(
                *[ResNeXtBlock(in_dim, groups, expansion_rate) for _ in range(num_block)],
                nn.GroupNorm(groups, in_dim),
                nn.Conv2d(in_dim, out_dim, kernel_size=(1, 2), stride=(1, 2)),
            )
            for (in_dim, out_dim), num_block in zip(in_out_dims, num_blocks)
        ])

        self.out_layer = nn.Linear(hidden_dims[-1], 1)

    def forward(self, x):
        x = self.stem(x)            # (B, model_channels, H, W // 4)
        x = self.resnext_blocks(x)  # (B, hidden_dims[-1], H, W')
        x = x.mean(-1)              # pool over the W axis -> (B, C, H)
        x = x.transpose(-1, -2)     # channels last -> (B, H, C)
        return self.out_layer(x)    # (B, H, 1)

# Trainer

In [None]:
class GeoTrainer(pl.LightningModule):
    """LightningModule that trains a ``ResNet1D`` regressor with an MSE loss.

    Args:
        num_channels: number of input channels for ``ResNet1D``
            (``NUM_CARDS`` in this notebook).
        lr: Adam learning rate (default 2e-4).
    """

    def __init__(self, num_channels, lr=2e-4):
        super().__init__()
        # Use the argument rather than the NUM_CARDS global so the module is self-contained.
        self.model = ResNet1D(num_channels=num_channels)
        self.loss = nn.MSELoss()
        self.lr = lr

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return [optimizer]

    def model_step(self, batch, stage):
        """Compute and log the MSE loss for one batch.

        Args:
            batch: ``(inputs, targets)`` pair from the DataLoader.
            stage: prefix for the logged metric name, e.g. ``'train'`` or ``'valid'``.

        Returns:
            The scalar loss tensor, still attached to the autograd graph.
        """
        inputs, targets = batch
        preds = self.model(inputs)
        loss = self.loss(preds, targets)
        # self.log accepts tensors directly; avoiding `.cpu().item()` removes a forced
        # device-to-host sync on every step.
        self.log(f'{stage}_loss', loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.model_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.model_step(batch, 'valid')

In [None]:
# TensorBoard logs go to ./tb_logs/test; the placeholder hp_metric is disabled.
tb_logger = TensorBoardLogger(name='test', save_dir="./tb_logs", default_hp_metric=False)

# Periodic checkpointing: every 5000 training steps, keeping all checkpoints.
# Disabled by default (matching the previous run); set to True to enable.
ENABLE_CHECKPOINTS = False
callbacks = [
    ModelCheckpoint(
        dirpath="./saved_models/test", filename="{step}", monitor="train_loss", mode="min",
        save_top_k=-1, every_n_train_steps=5000,
    )
]
trainer = pl.Trainer(
    logger=tb_logger,
    callbacks=callbacks if ENABLE_CHECKPOINTS else None,
    # `gpus=1` is deprecated (Lightning 1.7) and removed in 2.0; this is the equivalent.
    accelerator="gpu",
    devices=1,
    log_every_n_steps=5,
    max_steps=500000,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="value",
)
model = GeoTrainer(NUM_CARDS)

In [None]:
# Train on the synthetic GeoDataset loader (positional args: the keyword name for the
# train loader differs between Lightning versions).
# NOTE(review): no validation dataloader is passed, so GeoTrainer.validation_step and
# its 'valid_loss' metric are presumably never run; pass a val loader if needed.
trainer.fit(model, dataloader)