In [None]:
import torch
import torchvision
from torch import nn,optim
from torch.utils.data import Dataset,DataLoader,random_split
from torchvision.transforms import Compose,GaussianBlur,RandomAutocontrast
from torchvision.models import resnet18,ResNet18_Weights
import os
import lightning as L
from lightning.pytorch.callbacks import RichModelSummary,EarlyStopping
import gradio as gr
from collections import OrderedDict as Odict
import numpy as np
from tqdm.contrib.concurrent import process_map,thread_map
from tqdm import tqdm
from torchmetrics.regression import MeanAbsolutePercentageError,MeanAbsoluteError,MeanSquaredError
from torchmetrics.classification import MulticlassAveragePrecision
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from typing import Literal
import gc
torch.set_float32_matmul_precision('high')

In [2]:
class AlzDataset(Dataset):
    def __init__(self,x,y):
        super().__init__()
        self.x=x
        self.y=y
    def __len__(self):
        return self.x.shape[0]
    def __getitem__(self, index):
        return self.x[index],self.y[index]


class LitAlzData(L.LightningDataModule):
    def __init__(self,root='./dataset',transform=None,batch_size=128):
        super().__init__()
        self.root=root
        self.dirs=[]
        self.f=True
        self.class_map=[]
        self.x=[]
        self.y=[]
        self.total=0
        self.transform=transform
        self.workers=os.cpu_count()
        self.batch_size=batch_size
    def prepare_data(self):
        self.data=Odict()
        for root,dirs,files in os.walk(self.root):
            if len(dirs):
                if self.f:
                    self.dirs=[os.path.join(root,i) for i in dirs]
                    self.f=False
                    continue
                else:
                    print('prepare_data has aldready been run')
                    return
            self.data[root.split('/')[-1]]={'names':[files]}
        self.dirs.reverse()
        print(self.dirs)
        res=process_map(self.read_dir,
                        self.dirs,
                        tqdm_class=tqdm,
                        total=len(self.dirs),
                        desc='reading folders',
                        )
        for i in enumerate(self.data.keys()):
            self.data[i[1]]['images']=res[i[0]]
            self.total+=len(res[i[0]])
        self.class_map={i[1]:i[0] for i in enumerate(self.data.keys())}
        with tqdm(range(self.total),desc='adding to dataset') as pbar:
            for i in self.data.keys():
                for j in self.data[i]['images']:
                    if self.transform:
                        self.x.append(self.transform(j))
                    else:
                        self.x.append(j)
                    self.y.append(self.class_map[i])
                    pbar.update(1)
        self.x=torch.tensor(np.array(self.x)).repeat_interleave(3,1).float()
        self.y=torch.tensor(np.array(self.y))
        self.y=F.one_hot(self.y,4).float()
                
    def read_dir(self,root):
        # np.array(list(repeat(root+"/",times=len(os.listdir(root)))))+np.array(os.listdir(root))
        files=np.array([f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))])
        prefixes=np.repeat(root+"/",len(files))
        paths=list(prefixes+files)        
        # paths=root+"/"+np.array(os.listdir(root))
        imgs=thread_map(self.read_img,
                        paths,
                        tqdm_class=tqdm,
                        desc=f'reading files in {root}',
                        )
        return imgs

    def read_img(self,file):
        return torchvision.io.read_image(file)
    
    def setup(self,stage='fit'):
        self.train_ds,self.val_ds,self.test_ds=random_split(
            AlzDataset(self.x,self.y),
            [0.8,0.1,0.1], 
            generator=torch.Generator().manual_seed(42)
        )
        
    def train_dataloader(self):
        return DataLoader(self.train_ds,shuffle=True,num_workers=self.workers,batch_size=self.batch_size)
    
    def val_dataloader(self):
        return DataLoader(self.val_ds,shuffle=False,num_workers=self.workers,batch_size=self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.test_ds,shuffle=False,num_workers=self.workers,batch_size=self.batch_size)
    
    def predict_dataloader(self):
        return self.val_dataloader()
alzdata=LitAlzData(transform=Compose((RandomAutocontrast(),GaussianBlur(5,(0.01,2)))))
alzdata.prepare_data()
alzdata.x.shape,alzdata.y.shape

['./dataset/MildDemented', './dataset/VeryMildDemented', './dataset/NonDemented', './dataset/ModerateDemented']


reading files in ./dataset/ModerateDemented: 100%|██████████| 52/52 [00:00<00:00, 180699.10it/s]
reading files in ./dataset/MildDemented: 100%|██████████| 717/717 [00:00<00:00, 13354.10it/s]
reading files in ./dataset/VeryMildDemented: 100%|██████████| 1790/1790 [00:00<00:00, 14644.23it/s]
reading files in ./dataset/NonDemented: 100%|██████████| 2560/2560 [00:00<00:00, 9665.00it/s] 
reading folders: 100%|██████████| 4/4 [00:03<00:00,  1.10it/s]
adding to dataset: 100%|██████████| 5119/5119 [00:02<00:00, 1982.49it/s]


(torch.Size([5119, 3, 208, 176]), torch.Size([5119, 4]))

In [3]:
class LitClasssifer(L.LightningModule):
    def __init__(self,classes=4):
        super().__init__()
        self.classes=classes
        self.metrics=[MulticlassAveragePrecision(4),MeanAbsolutePercentageError(),MeanAbsoluteError(),MeanSquaredError()]
        self.model=resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.model.fc=nn.Sequential(nn.Linear(512,classes))
        self.loss_fn=nn.CrossEntropyLoss()

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(),lr=1e-4)

    def forward(self,x):
        return self.model(x)
    
    def log_metrics(self,stage:Literal['fit','val','test'],outputs,y,loss):
        self.log(f'{stage}_loss',loss,logger=True,prog_bar=True)
        for i in enumerate(self.metrics):
            if i[0]==0:
                i=i[1]
                self.log(f'{stage}_{i._get_name()}',i.to(outputs)(outputs,y.argmax(1)),logger=True,prog_bar=True)
            else:
                i=i[1]
                self.log(f'{stage}_{i._get_name()}',i.to(outputs)(outputs,y),logger=True,prog_bar=True)

    def perform_step(self,batch,stage:Literal['fit','val','test']):
        x,y=batch
        outputs=self.model(x)
        loss=self.loss_fn(outputs,y)
        self.log_metrics(stage,outputs,y,loss)
        return loss

    def training_step(self,batch,batch_idx):
        return self.perform_step(batch,'fit')
    
    def validation_step(self,batch,batch_idx):
        self.perform_step(batch,'val')
    
    def test_step(self,batch,batch_idx):
        self.perform_step(batch,'test')
        
    def predict_step(self,batch):
        x,_=batch
        outputs=self.model(x)
        return outputs
model=LitClasssifer()
model

LitClasssifer(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [4]:
trainer=L.Trainer(
    max_epochs=20,
    callbacks=[RichModelSummary(2),EarlyStopping(monitor='val_loss',min_delta=0.1)],
    log_every_n_steps=32
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [5]:
trainer.fit(
    model,
    datamodule=alzdata
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


prepare_data has aldready been run


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [6]:
trainer.validate(
    model,
    datamodule=alzdata
)
trainer.test(
    model,
    datamodule=alzdata
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


prepare_data has aldready been run


Validation: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


prepare_data has aldready been run


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.12214943021535873,
  'test_MulticlassAveragePrecision': 0.9932016730308533,
  'test_MeanAbsolutePercentageError': 1629736.875,
  'test_MeanAbsoluteError': 2.7420859336853027,
  'test_MeanSquaredError': 9.525773048400879}]

In [7]:
torch.cuda.empty_cache()
gc.collect()

475