In [1]:
# Import required libraries
import glob
import os

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locate the data
# The dataset is expected to be laid out as <split>/<class_name>/<sample>.npy,
# with the split directories ("train/", "val/") already present on disk.

# Class name (= sub-directory name) -> integer label
categories = {name: index for index, name in enumerate(("no", "sphere", "vort"))}

# Root of the dataset; an empty string means "relative to the working directory"
dataset_path = ""
train_dataset_path = f"{dataset_path}train/"
val_dataset_path = f"{dataset_path}val/"

# Collect every .npy file that sits one directory below each split root
train_files = glob.glob(f"{train_dataset_path}*/*.npy")
val_files = glob.glob(f"{val_dataset_path}*/*.npy")

In [None]:
# Dataset object
# Use albumentations for data augmentation
class SubstructureDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads one ``.npy`` array per sample.

    The label of a sample is derived from the name of the directory that
    directly contains its file (``.../<class_name>/<file>.npy``), translated
    to an integer id through ``label_map``.

    Args:
        data_files: Sequence of paths to ``.npy`` files.
        transform: Optional callable invoked as ``transform(image=array)`` that
            returns a mapping with an ``"image"`` entry (the calling convention
            used by albumentations pipelines).
        label_map: Optional mapping from directory name to integer class id.
            When omitted, the module-level ``categories`` dict is used.

    Raises:
        KeyError: From ``__getitem__`` when a file's parent directory name is
            not a key of the label mapping.
    """

    def __init__(self, data_files, transform=None, label_map=None):
        self.data_files = data_files
        self.transform = transform
        # The global is only looked up when no explicit mapping is supplied.
        self.label_map = categories if label_map is None else label_map

    def __len__(self):
        """Return the number of sample files."""
        return len(self.data_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)`` as tensors.

        ``image`` is a float32 tensor with the shape of the (optionally
        transformed) stored array; ``label`` is a 0-d int64 tensor.
        """
        path = self.data_files[idx]
        image = np.load(path)

        # os.path understands the platform's separators, so this also works for
        # the backslash paths glob produces on Windows (a plain split("/") does not).
        class_name = os.path.basename(os.path.dirname(path))
        label = self.label_map[class_name]

        # `is not None` instead of truthiness: a transform object that defines
        # __len__ could otherwise be skipped just because it is "empty".
        if self.transform is not None:
            # NOTE(review): albumentations treats arrays as (H, W[, C]); the layout
            # of the stored arrays is not visible here -- confirm it matches both
            # the augmentations and the model's expected input.
            image = self.transform(image=image)["image"]

        # Flips/rotations may hand back a view with negative strides, which
        # tensor conversion can reject; make it contiguous first.
        image = torch.tensor(np.ascontiguousarray(image), dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        return image, label

In [None]:
# Define pytorch model
# Use timm for pretrained models
class Model(torch.nn.Module):
    """Image classifier built on a timm backbone.

    Args:
        num_classes: Number of output classes of the classification head.
        model_name: Name of the timm architecture to instantiate.
        pretrained: Whether to load timm's pretrained weights for the backbone.
        in_chans: Number of input channels the network should accept.
    """

    def __init__(self, num_classes, model_name="efficientnet_b4", pretrained=True, in_chans=3):
        super().__init__()
        # Let timm build the classification head. Replacing `.fc` by hand only
        # works for architectures whose head is called `fc`; EfficientNet's is
        # `classifier`, so accessing `self.model.fc` raised AttributeError.
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=in_chans,
        )

    def forward(self, x):
        """Run the backbone on ``x`` and return its output (class logits).

        NOTE(review): assumes ``x`` is a (batch, in_chans, H, W) float tensor --
        confirm against what the dataset actually yields.
        """
        return self.model(x)


In [None]:
# Define pl module
class LitModel(pl.LightningModule):
    """Lightning wrapper that trains ``Model`` with cross-entropy loss and Adam.

    Args:
        num_classes: Number of classes, forwarded to ``Model``.
        lr: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, num_classes, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.model = Model(num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def _compute_batch_loss(self, batch):
        """Unpack an ``(inputs, targets)`` batch and return its criterion loss."""
        inputs, targets = batch
        predictions = self(inputs)
        return self.criterion(predictions, targets)

    def training_step(self, batch, batch_idx):
        """Log the batch loss as ``train_loss`` and return it."""
        train_loss = self._compute_batch_loss(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Log the batch loss as ``val_loss``; nothing is returned."""
        self.log("val_loss", self._compute_batch_loss(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Define dataloaders
# Random flips and 90-degree rotations are applied to training samples only;
# validation samples are loaded without any transform.
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

train_dataset = SubstructureDataset(train_files, transform=train_transform)
val_dataset = SubstructureDataset(val_files)

# Settings shared by both loaders; only shuffling differs between them
_loader_kwargs = {"batch_size": 32, "num_workers": 4}

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    **_loader_kwargs,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    **_loader_kwargs,
)

In [None]:
# Train model
# Seed python / numpy / torch (and dataloader workers) so weight initialisation,
# shuffling and augmentation are repeatable across Restart & Run All.
pl.seed_everything(42, workers=True)

# Size the classification head from the label mapping instead of a literal.
model = LitModel(num_classes=len(categories))

# `gpus=` was removed from Trainer in Lightning 2.0; `accelerator`/`devices`
# replace it. "auto" uses a GPU when one is available and falls back to CPU
# otherwise, so the notebook no longer hard-requires a GPU.
trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)