In [35]:
import lightning as L
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split

In [36]:

class DataMod(L.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.mnist_predict = None
        self.mnist_val = None
        self.mnist_test = None
        self.mnist_train = None
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
    
    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    
    def setup(self, stage: str):
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size, num_workers=self.num_workers)

In [41]:
class ClassifierMod(L.LightningModule):
    def __init__(self, lr = 0.001):
        super().__init__()
        self.model = models.efficientnet_b7()
        self.model.conv_stem = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
        
        in_features = self.model.conv_stem.in_channels
        self.model.conv_stem = torch.nn.Linear(in_features, 10)  # Assuming 10 classes for MNIST
        self.learning_rate = lr
        
    def forward(self, images):
        return self.model(images)
    
    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.forward(images)
        loss = torch.CrossEntropyLoss()(logits, labels)
        return loss
    
    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.forward(images)
        loss = torch.CrossEntropyLoss()(logits, labels)
        return loss
    
    def test_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.forward(images)
        loss = torch.CrossEntropyLoss()(logits, labels)
        return loss
    
    def predict_step(self, batch, batch_idx):
        images, labels = batch
        predictions = self.forward(images)
        return predictions
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [42]:
model = ClassifierMod(0.001)

In [43]:
trainer = L.Trainer(max_epochs=10)
dm = DataMod("./data")
trainer.fit(model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type         | Params
---------------------------------------
0 | model | EfficientNet | 66.3 M
---------------------------------------
66.3 M    Trainable params
0         Non-trainable params
66.3 M    Total params
265.392   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[32, 1, 28, 28] to have 3 channels, but got 1 channels instead

In [34]:
trainer.predict(model, datamodule=DataMod)

AttributeError: type object 'DataMod' has no attribute 'prepare_data_per_node'