In [2]:
import sys
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

## Model

In [167]:
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

class ConvBlock(nn.Module):

    def __init__(self, conv_in_channels: int, conv_out_channels: int, conv_kernel_size: int, conv_stride: int, pool_kernel_size: int) -> None:
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=conv_in_channels,
                      out_channels=conv_out_channels,
                      kernel_size=conv_kernel_size,
                      stride=conv_stride),
            nn.ReLU(),
            # nn.MaxPool1d(pool_kernel_size),
            nn.Dropout(0.3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        return self.layers(x)


class PyTorchFCN(torch.nn.Module):
    """A PyTorch implementation of the FCN Baseline
    From https://arxiv.org/abs/1909.04939

    Attributes
    ----------
    sequence_length:
        The size of the input sequence
    num_pred_classes:
        The number of output classes
    """
    
    def __init__(self, in_channels: int = 2, num_pred_classes: int = 2) -> None:
        super(PyTorchFCN, self).__init__()

        self.conv_layers = nn.Sequential(*[
            ConvBlock(in_channels, 1024, 8, 1, 3),
            ConvBlock(1024, 516, 5, 1, 3),
            ConvBlock(516, 256, 3, 1, 3),
        ])
        self.classifier = nn.Sequential(*[
            nn.Linear(9216, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, num_pred_classes)
        ])
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_layers(x)
        x = self.flatten(x)
        return self.classifier(x)


class LightningModel(L.LightningModule):

    def __init__(self, model=None, hidden_dim=None, num_hidden_layers=None, learning_rate=None, seq_len=None):
        super(LightningModel, self).__init__()

        self.save_hyperparameters()

        self.num_features = 2
        self.num_classes = 2
        self.learning_rate = learning_rate
        self.seq_len = 49

        # model
        if model is None:
            self.model = PyTorchFCN(self.num_features, self.num_classes)

        # metrics
        metrics = MetricCollection([
            MulticlassAccuracy(num_classes=self.num_classes),
            MulticlassPrecision(num_classes=self.num_classes),
            MulticlassRecall(num_classes=self.num_classes),
            MulticlassF1Score(num_classes=self.num_classes)
        ])
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        loss, true_labels, logits = self._shared_step(batch)

        self.log('train_loss', loss)
        self.train_metrics(logits, true_labels)
        self.log_dict(self.train_metrics, on_epoch=True, on_step=False)
        return loss

    def training_step_end(self, training_step_outputs):
        return {'loss': training_step_outputs['loss'].sum()}

    def validation_step(self, batch, batch_nb):
        loss, true_labels, logits = self._shared_step(batch)

        self.log('val_loss', loss)
        self.val_metrics(logits, true_labels)
        self.log_dict(self.val_metrics)

        return loss

    def _shared_step(self, batch):
        features, true_labels = batch

        logits = self(features)
        loss = F.cross_entropy(logits, true_labels)
        return loss, true_labels, logits

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.learning_rate)

        scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=self._num_steps())
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]

    def _num_steps(self) -> int:
        """Get number of steps"""
        train_dataloader = self.trainer.datamodule.train_dataloader()
        dataset_size = len(train_dataloader.dataset)
        num_steps = dataset_size * self.trainer.max_epochs // self.trainer.datamodule.batch_size
        return num_steps

## Checkpoint

In [168]:
checkpoint_path = './chechpoints-gpuhub/neural-nappers/psd3mcwl/checkpoints/epoch=4-step=730605.ckpt'

In [169]:
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

In [170]:
model = LightningModel.load_from_checkpoint(checkpoint_path)

In [154]:
state_dict = checkpoint['state_dict']
print(state_dict['model.layers.0.layers.0.weight'].shape)
print(state_dict['model.layers.1.layers.0.weight'].shape)

KeyError: 'model.layers.0.layers.0.weight'

In [93]:
state_dict.keys()

odict_keys(['model.conv_layers.0.layers.0.weight', 'model.conv_layers.0.layers.0.bias', 'model.conv_layers.1.layers.0.weight', 'model.conv_layers.1.layers.0.bias', 'model.classifier.0.weight', 'model.classifier.0.bias', 'model.classifier.2.weight', 'model.classifier.2.bias'])