In [2]:
# train.py
import lightning as L
import torch
from dataset import MNIST3DModule
from model import LitBasicMLP
from config import *
from pytorch_lightning.loggers import WandbLogger

# dataset.py
import lightning as L
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# model.py
import lightning.pytorch as L
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# config.py
import os

### Define the Hyperparameters

In [112]:
# Notebook-wide configuration (inline stand-in for config.py).

# --- Model ---
# Length of one flattened 28 x 28 x 28 input volume.
INPUT_DIM = 28 ** 3
OUTPUT_DIM = 10
HIDDEN_DIM = 1024
DROPOUT = 0.2

# --- Optimisation ---
LEARNING_RATE = 3e-5
BATCH_SIZE = 1
MIN_EPOCHS = 1
MAX_EPOCHS = 15

# --- Data loading ---
# Root folder for the MNIST files.
DATA_DIR = "./MNIST"
NUM_WORKERS = 0

# --- Hardware ---
ACCELERATOR = "cpu"
DEVICES = 1

# --- Checkpoints ---
# Folder for saved models; the PATH_CHECKPOINT environment variable overrides
# the default location.
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/3DMLP")


### Define the model

In [113]:
# define basic model

class BasicMLP(nn.Module):
    """Single-hidden-layer perceptron that returns raw class logits.

    Architecture: ``Linear -> ReLU -> Dropout -> Linear``.

    Args:
        input_dim: Number of features in each (already flattened) input row.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output classes.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.linearIn = nn.Linear(self.input_dim, self.hidden_dim)
        self.linearOut = nn.Linear(self.hidden_dim, self.output_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, output_dim)`` logits.

        The result is deliberately left unnormalised. ``F.cross_entropy``
        applies ``log_softmax`` itself, so an extra softmax here would be
        applied twice and squash the gradients, and a ReLU on the output layer
        would throw away every negative logit. Apply ``F.softmax(logits, dim=1)``
        at the call site when probabilities are needed; ``argmax`` predictions
        can be taken from the logits directly.
        """
        x = self.linearIn(x)
        x = self.activation(x)
        x = self.dropout(x)
        return self.linearOut(x)

In [114]:
class LitBasicMLP(L.LightningModule):
    """Lightning wrapper that trains a ``BasicMLP`` classifier.

    Args:
        model_kwargs: Keyword arguments forwarded unchanged to ``BasicMLP``.
        lr: Learning rate handed to the AdamW optimizer.
    """

    def __init__(self, model_kwargs, lr):
        super().__init__()
        # Records the constructor arguments so they are reachable via
        # ``self.hparams`` (``configure_optimizers`` reads ``self.hparams.lr``).
        self.save_hyperparameters()
        self.model = BasicMLP(**model_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

    def configure_optimizers(self):
        """Return AdamW paired with a MultiStepLR schedule.

        The schedule multiplies the learning rate by 0.1 at milestones 100
        and 150.
        """
        adamw = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        step_decay = optim.lr_scheduler.MultiStepLR(
            adamw,
            milestones=[100, 150],
            gamma=0.1,
        )
        return [adamw], [step_decay]

    def _calculate_loss(self, batch, mode="train"):
        """Compute cross-entropy loss and accuracy for one batch and log both.

        Args:
            batch: An ``(inputs, targets)`` pair.
            mode: Prefix for the logged metric names (``"train"``, ``"val"``
                or ``"test"`` in this class).

        Returns:
            The cross-entropy loss tensor.
        """
        inputs, targets = batch
        outputs = self.model(inputs)

        loss = F.cross_entropy(outputs, targets)
        # Fraction of rows whose highest-scoring class matches the target.
        hits = outputs.argmax(dim=-1) == targets
        accuracy = hits.float().mean()

        self.log(f"{mode}_loss", loss, prog_bar=True)
        self.log(f"{mode}_acc", accuracy, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for Lightning to backpropagate."""
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        """Log validation metrics; nothing is returned."""
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        """Log test metrics; nothing is returned."""
        self._calculate_loss(batch, mode="test")


### Define the Dataset and DataModule to use

In [115]:
class To3D:
    """Transform that broadcasts a 2D image tensor into a cubic volume.

    The output shape is ``(x.shape[0], x.shape[1], x.shape[1], x.shape[1])``.
    For a ``(1, N, N)`` input this gives a ``(1, N, N, N)`` view in which the
    image is repeated along the new depth axis. No data is copied, because
    ``Tensor.expand`` returns a view.
    """

    def __init__(self) -> None:
        """Nothing to configure."""

    def __call__(self, x):
        channels, side = x.shape[0], x.shape[1]
        return x.expand(channels, side, side, side)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

class Flatten:
    """Transform that collapses a tensor of any rank into one dimension."""

    def __init__(self) -> None:
        """Nothing to configure."""

    def __call__(self, x):
        # Spell out torch.flatten's defaults: merge every axis, first to last.
        return torch.flatten(x, start_dim=0, end_dim=-1)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


In [116]:
class MNIST3DModule(L.LightningDataModule):
    """Lightning data module that serves MNIST digits as flattened 3D volumes.

    Every image goes through ``ToTensor -> To3D -> Flatten``. The MNIST
    training set is split 55,000 / 5,000 into train / validation subsets; the
    MNIST test set is used as is.

    Args:
        data_dir: Root directory the test set is read from (and downloaded to).
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker count passed to every ``DataLoader``.
        dataset_path: Root directory the train/validation images are read from.
    """

    def __init__(self, data_dir, batch_size, num_workers, dataset_path):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_path = dataset_path
        self.dataset = MNIST

    def prepare_data(self):
        """Download MNIST into every root directory that ``setup`` reads from."""
        self.dataset(self.data_dir, train=True, download=True)
        self.dataset(self.data_dir, train=False, download=True)
        # setup() loads the train/val images from ``dataset_path`` without
        # download=True, so make sure the files exist there too when it is a
        # different directory from ``data_dir``.
        if self.dataset_path != self.data_dir:
            self.dataset(self.dataset_path, train=True, download=True)

    @staticmethod
    def _make_transform():
        """Build the image -> flattened-volume preprocessing pipeline."""
        return transforms.Compose(
            [
                transforms.ToTensor(),
                To3D(),
                Flatten(),
            ]
        )

    def setup(self, stage=None):
        """Create ``train_ds``, ``val_ds`` and ``test_ds``.

        Args:
            stage: Lightning stage name. Accepted for API compatibility and
                ignored: all three splits are always built.
        """
        # No augmentation is applied at the moment, so both pipelines are
        # identical. They stay separate objects so augmentation can later be
        # added to the training pipeline only.
        train_transform = self._make_transform()
        test_transform = self._make_transform()

        # The training images are wrapped twice so the train and validation
        # subsets can carry different transforms.
        train_dataset = self.dataset(
            root=self.dataset_path,
            train=True,
            transform=train_transform,
        )
        val_dataset = self.dataset(
            root=self.dataset_path,
            train=True,
            transform=test_transform,
        )

        # Re-seeding before each split makes both calls draw the same
        # permutation, so the 5,000 validation indices never overlap the 55,000
        # training indices. NOTE: seed_everything resets the *global* RNG state.
        L.seed_everything(42)
        self.train_ds, _ = torch.utils.data.random_split(train_dataset, [55000, 5000])
        L.seed_everything(42)
        _, self.val_ds = torch.utils.data.random_split(val_dataset, [55000, 5000])

        self.test_ds = self.dataset(
            root=self.data_dir, train=False, transform=test_transform, download=True
        )

    def _loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        return self._loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Return the in-order loader over the validation subset."""
        return self._loader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Return the in-order loader over the test set."""
        return self._loader(self.test_ds, shuffle=False)

### Simple tests to make sure everything works

In [117]:
# Build the Lightning model from the configuration constants.
model = LitBasicMLP(
    model_kwargs=dict(
        input_dim=INPUT_DIM,
        hidden_dim=HIDDEN_DIM,
        output_dim=OUTPUT_DIM,
        dropout=DROPOUT,
    ),
    lr=LEARNING_RATE,
)

# The same folder is used for downloading and for reading the datasets.
dataModule = MNIST3DModule(
    data_dir=DATA_DIR,
    dataset_path=DATA_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)


In [118]:
# Download MNIST if necessary, then build the train/val/test datasets.
dataModule.prepare_data()
# "fit" is Lightning's stage name for training ("train" is not one of its
# stages: fit / validate / test / predict).
dataModule.setup(stage="fit")
train_data_loader = dataModule.train_dataloader()

Global seed set to 42
Global seed set to 42


In [119]:
# Sanity check: the loader should report the batch size it was built with.
print(train_data_loader.batch_size)

1


In [120]:
# Pull a single (inputs, targets) batch from the training loader.
original, label = next(iter(train_data_loader))

In [121]:
# Shape of the input batch after the ToTensor -> To3D -> Flatten pipeline.
print(original.shape)

torch.Size([1, 21952])


In [122]:
# One random index into the training split (not used within this cell).
idx = torch.randint(0, len(dataModule.train_ds), (1,))
index = torch.tensor([2], dtype=torch.long)

# Remove every size-1 axis from the batch, leaving a single flat sample.
example = torch.squeeze(original)
print(example.shape)

torch.Size([21952])


In [123]:
# import matplotlib.pyplot as plt
# plt.imshow(example, cmap="gray")
# plt.show()

In [124]:
# Call the module itself rather than .forward() so nn.Module's __call__
# machinery (hooks) runs as it would during training.
print(model(original).shape)

torch.Size([1, 10])


In [125]:
# Smoke-test the loss/accuracy computation on one batch, invoked directly
# rather than through a Trainer.
model._calculate_loss((original, label), mode="train")

  rank_zero_warn(


tensor(2.3056, grad_fn=<NllLossBackward0>)