<a href="https://www.kaggle.com/code/chhelp/lb-99-446-lightning-resnet18-digit-recognizer?scriptVersionId=149267374" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory lazily and print the full path of every file found.
input_file_paths = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/digit-recognizer/sample_submission.csv
/kaggle/input/digit-recognizer/train.csv
/kaggle/input/digit-recognizer/test.csv


In [2]:
!pip install pytorch-lightning



In [3]:
import os
import timm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
from torch import optim, nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import transforms as T
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix
from torchmetrics import Accuracy

import warnings
warnings.filterwarnings('ignore')



In [4]:
# Directory holding the competition CSV files.
path = '/kaggle/input/digit-recognizer'

# Labelled training rows, unlabelled test rows, and the submission template.
train = pd.read_csv(f'{path}/train.csv')
test = pd.read_csv(f'{path}/test.csv')
sample_submission = pd.read_csv(f'{path}/sample_submission.csv')

In [5]:
# Preview the first five rows of the submission template.
sample_submission.head(n=5)

Unnamed: 0,ImageId,Label
0,1,0
1,2,0
2,3,0
3,4,0
4,5,0


In [6]:
# Preview the first five rows of the unlabelled test frame.
test.head(n=5)

Unnamed: 0,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,pixel9,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [7]:
# Preview the first five rows of the labelled training frame.
train.head(n=5)

Unnamed: 0,label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,...,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [8]:
class DigitDataset(Dataset):
    """Dataset serving 28x28 digit images from a flat-pixel DataFrame.

    Args:
        df: DataFrame with one row per image. When ``train`` is true it must
            contain a ``label`` column; every other column is treated as a pixel.
        train: If true, items are ``(image, label)`` pairs; otherwise only the
            image is returned.
        transform: Optional callable applied to each ``(28, 28)`` float32 array.

    Pixel values are divided by 255 on construction. This assumes the pixel
    columns hold 0-255 intensities, as in the Kaggle digit-recognizer CSVs.
    """

    def __init__(self, df, train=True, transform=None):
        self.train = train
        self.transform = transform

        if train:
            pixels = df.drop(['label'], axis=1)
            self.label = df['label'].values.tolist()
        else:
            pixels = df

        # torchvision's ToTensor only rescales uint8 input; a float32 array is
        # passed through unchanged. Without this division, 0-255 values would
        # reach Normalize, whose mean/std constants are meant for [0, 1] data.
        self.X = pixels.values.astype('float32') / 255.0

    def __len__(self):
        """Return the number of images."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the (optionally transformed) image at ``idx``, plus its label in train mode."""
        img = self.X[idx].reshape(28, 28)
        if self.transform:
            img = self.transform(img)

        if not self.train:
            return img
        return img, self.label[idx]

In [9]:
class DigitModule(pl.LightningModule):
    """LightningModule wrapping a single-input-channel timm classifier.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        pretrained: Whether timm should load pretrained weights.
        num_classes: Number of output classes of the classification head.
        lr: Initial learning rate for Adam.
        t_max: ``T_max`` of the cosine annealing schedule (number of scheduler
            steps until the learning rate reaches ``eta_min``).
        eta_min: Floor of the cosine annealing schedule. It must be lower than
            ``lr``; when the two are equal the schedule is flat and the
            scheduler has no effect.
    """

    def __init__(self, model_name='resnet18', pretrained=True, num_classes=10,
                 lr=0.001, t_max=20, eta_min=1e-5):
        super().__init__()
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            in_chans=1,  # images are fed as a single channel
            num_classes=num_classes,
        )

        # Optimiser / scheduler settings, consumed by configure_optimizers().
        self.lr = lr
        self.t_max = t_max
        self.eta_min = eta_min

        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return the raw class scores (logits) of the wrapped model for ``x``."""
        return self.model(x)

    def common(self, batch, batch_idx):
        """Shared step logic: run the model on a labelled batch.

        Returns:
            Tuple ``(loss, scores, y)`` with the cross-entropy loss, the class
            scores and the target labels.
        """
        x, y = batch
        scores = self(x)
        loss = self.loss_fn(scores, y)
        return loss, scores, y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss, scores, y = self.common(batch, batch_idx)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy; ``accuracy`` is the key watched by the checkpoint callback."""
        loss, scores, y = self.common(batch, batch_idx)
        y_pred = torch.argmax(scores, dim=1)
        accuracy = self.accuracy(y_pred, y)
        self.log_dict({
            'validation_loss': loss,
            'accuracy': accuracy,
        })
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the loss for one labelled test batch."""
        loss, scores, y = self.common(batch, batch_idx)
        self.log('test_loss', loss)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return class scores for an unlabelled batch (the batch is the input tensor itself)."""
        x = batch
        return self(x)

    def configure_optimizers(self):
        """Build Adam with a cosine-annealed learning rate.

        The schedule decays from ``self.lr`` towards ``self.eta_min`` over
        ``self.t_max`` scheduler steps.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.t_max, eta_min=self.eta_min)
        return [optimizer], [scheduler]

In [10]:
class DigitDataModule(pl.LightningDataModule):
    """DataModule that serves pre-built train / validation / predict datasets.

    The datasets are supplied by the caller; this class only wraps them in
    ``DataLoader`` objects sharing one batch size and worker count.
    """

    def __init__(self, train_ds=None, val_ds=None, predict_ds=None, batch_size=32, num_workers=4):
        super().__init__()
        self.train_ds, self.val_ds, self.predict_ds = train_ds, val_ds, predict_ds
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """No-op: the datasets are already constructed by the caller."""

    def setup(self, stage):
        """No-op: there is nothing stage-specific to set up."""

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._build_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation dataset."""
        return self._build_loader(self.val_ds, shuffle=False)

    def predict_dataloader(self):
        """Ordered loader over the prediction dataset."""
        return self._build_loader(self.predict_ds, shuffle=False)

In [11]:
# Experiment configuration.
model_name = 'resnet18'               # architecture name later passed to timm
num_label = train['label'].nunique()  # number of distinct target classes
num_fold = 6                          # cross-validation splits
max_epochs = 20                       # upper bound on epochs per fold
batch_size = 4096                     # i.e. 128 * 32 samples per batch

In [12]:
# Shuffled, label-stratified splitter; the fixed seed keeps the folds repeatable.
skf = StratifiedKFold(
    n_splits=num_fold,
    random_state=42,
    shuffle=True,
)

In [13]:
# Checkpoints and trainer artefacts are written to the current working directory.
OUTPUT_DIR = "."

In [14]:
# Training pipeline: tensor conversion, a light random affine jitter, then
# normalisation (the constants look like the usual MNIST mean/std — confirm).
transform_train = T.Compose([
    T.ToTensor(),
    T.RandomAffine(degrees=(-5, 5), translate=(0.025, 0.025), scale=(0.975, 1.025)),
    T.Normalize((0.1307,), (0.3081,)),
])
# Evaluation pipeline: same conversion and normalisation, no augmentation.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
])

# The test set is the same for every fold, so build its dataset only once.
test_ds = DigitDataset(test, train=False, transform=transform_test)

# One entry per fold; each entry is the list of per-batch score tensors
# returned by trainer.predict for the test set.
preds = []
for fold, (train_idx, val_idx) in enumerate(skf.split(train, train['label'])):
    # Seed python / numpy / torch (and dataloader workers) so a fold's run is repeatable.
    pl.seed_everything(42 + fold, workers=True)

    train_ds = DigitDataset(train.iloc[train_idx], train=True, transform=transform_train)
    val_ds = DigitDataset(train.iloc[val_idx], train=True, transform=transform_test)
    model = DigitModule(model_name=model_name, pretrained=True, num_classes=num_label)
    dm = DigitDataModule(
        train_ds=train_ds,
        val_ds=val_ds,
        predict_ds=test_ds,
        batch_size=batch_size,
        num_workers=4,
    )
    # Keep only the weights of the epoch with the highest validation accuracy.
    best_accuracy_checkpoint = ModelCheckpoint(
        dirpath=OUTPUT_DIR,
        filename=f'{model_name}_{fold}',
        monitor='accuracy',
        save_top_k=1,
        save_weights_only=True,
        mode='max',
    )
    # No accelerator is specified, so Lightning chooses one automatically.
    trainer = pl.Trainer(
        callbacks=[best_accuracy_checkpoint],
        default_root_dir=OUTPUT_DIR,
        devices=1,
        min_epochs=1,
        max_epochs=max_epochs,
        precision=16,
    )

    trainer.fit(model, dm)
    # Evaluate and predict with the best checkpoint rather than whatever weights
    # the final epoch left in `model`; otherwise the checkpoint is never used.
    trainer.validate(model, datamodule=dm, ckpt_path='best')
    pred = trainer.predict(model, datamodule=dm, ckpt_path='best')
    preds.append(pred)

Downloading model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Predicting: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Predicting: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Predicting: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Predicting: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Predicting: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Predicting: 0it [00:00, ?it/s]

In [15]:
# Sanity check: number of folds, prediction batches per fold, shape of one batch.
first_fold_batches = preds[0]
(len(preds), len(first_fold_batches), first_fold_batches[0].shape)

(6, 7, torch.Size([4096, 10]))

In [16]:
# Join each fold's prediction batches into one score matrix, then stack the
# per-fold matrices along a new leading (fold) dimension.
per_fold_scores = []
for fold_batches in preds:
    per_fold_scores.append(torch.cat(fold_batches, dim=0))
y_pred = torch.stack(per_fold_scores, dim=0)

In [17]:
# Recompute from `preds` instead of overwriting `y_pred` with a function of
# itself, so this cell gives the same result when re-run. Average the scores
# over folds, then take the highest-scoring class for each test image.
y_pred = torch.stack([torch.concat(pred) for pred in preds]).mean(dim=0).argmax(dim=1)

In [18]:
# Convert the prediction tensor to a NumPy array explicitly rather than relying
# on pandas' implicit handling of torch tensors; .cpu() also covers the case
# where the predictions still live on an accelerator.
sample_submission['Label'] = y_pred.detach().cpu().numpy()
sample_submission.to_csv('submission.csv', index=False)

In [19]:
# Preview the first five rows of the filled-in submission.
sample_submission.head(n=5)

Unnamed: 0,ImageId,Label
0,1,2
1,2,0
2,3,9
3,4,0
4,5,3
