In [None]:
!pip install datasets transformers timm lightning torchmetrics -q

In [None]:
import lightning as L
import torch
from torch import nn
import torchmetrics
from torch.utils.data import Dataset, DataLoader
import timm
import transformers
import datasets
import torchvision.transforms as T

In [None]:
# Build the train / validation / test splits for Food-101.
from datasets import load_dataset

dataset = load_dataset("food101")
# Hold out 10% of the original training split as a test set. The fixed seed
# keeps the held-out examples identical across kernel restarts, so results
# from different runs are comparable.
full = dataset['train'].train_test_split(test_size=0.1, seed=42)
ds = datasets.DatasetDict({
    "train" : full['train'],
    "validation" : dataset['validation'],
    "test" : full["test"]
})

In [None]:
# Display the assembled DatasetDict as the cell output.
ds

In [None]:
# Lookup tables between integer class ids and class names, taken from the
# 'label' feature of the training split.
class_names = ds['train'].features['label'].names
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Spatial size every image is resized to before entering the model.
img_size = (224,224)

# Per-channel normalisation shared by the train and eval pipelines.
# NOTE(review): the pretrained checkpoint may expect different mean/std
# statistics — confirm against the model's pretrained config.
_normalize = T.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))

# Training pipeline: resize, random rotation / horizontal flip, crop back to
# img_size, then tensor conversion and normalisation.
_train_steps = [
    T.Resize(img_size),
    T.RandomRotation(30),
    T.RandomHorizontalFlip(),
    T.CenterCrop(img_size),
    T.ToTensor(),
    _normalize,
]
train_tf = T.Compose(_train_steps)

# Evaluation pipeline: same resize and normalisation, no random augmentation.
_eval_steps = [
    T.Resize(img_size),
    T.ToTensor(),
    _normalize,
]
test_tf = T.Compose(_eval_steps)

In [None]:
def train_transform(batch):
    processed = {}
    processed['pixel_values'] = torch.stack([train_tf(img.convert('RGB'))for img in batch['image']],dim=0)
    processed['label'] = torch.tensor(batch['label'])
    return processed

def val_transform(batch):
    processed = {}
    processed['pixel_values'] = torch.stack([test_tf(img.convert('RGB'))for img in batch['image']],dim=0)
    processed['label'] = torch.tensor(batch['label'])
    return processed

def test_transform(batch):
    processed = {}
    processed['pixel_values'] = torch.stack([test_tf(img.convert('RGB'))for img in batch['image']],dim=0)
    processed['label'] = torch.tensor(batch['label'])
    return processed

In [None]:
# Attach the matching batch transform to each split, then show the split sizes.
train_ds, val_ds, test_ds = (
    ds[split_name].with_transform(transform_fn)
    for split_name, transform_fn in (
        ('train', train_transform),
        ('validation', val_transform),
        ('test', test_transform),
    )
)
tuple(len(split_ds) for split_ds in (train_ds, val_ds, test_ds))

In [None]:
# One DataLoader per split; only the training loader shuffles.
batch_size=64
_loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Pretrained timm model, created with num_classes set to the number of labels.
model_name = 'swin_s3_base_224'
model = timm.create_model(
    model_name,
    pretrained=True,
    num_classes=len(id2label),
)

In [None]:
class Swin_Lightning(L.LightningModule):
    """Lightning wrapper that trains an image classifier with cross-entropy.

    Batches are expected to be dicts with 'pixel_values' (model input) and
    'label' (integer class targets).

    Args:
        model: Module mapping ``batch['pixel_values']`` to class logits.
        lr: Learning rate handed to AdamW.
        num_classes: Number of classes for the accuracy metrics. When omitted,
            falls back to ``len(id2label)`` from the notebook scope, which is
            what the class used unconditionally before.
    """

    # Logging options shared by every metric: epoch-level aggregation only,
    # shown in the progress bar.
    _LOG_KWARGS = dict(prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

    def __init__(self,model,lr,num_classes=None):
        super().__init__()
        if num_classes is None:
            # Backward-compatible default: read the class count from the notebook global.
            num_classes = len(id2label)
        self.model = model
        self.lr = lr
        # Separate metric objects per stage so their accumulated state never mixes.
        self.training_acc = torchmetrics.Accuracy(task='multiclass',num_classes=num_classes)
        self.validation_acc = torchmetrics.Accuracy(task='multiclass',num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task='multiclass',num_classes=num_classes)

    def forward(self,x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _run_step(self,batch,stage,metric,with_loss=True):
        """Forward one batch, update ``metric`` and log under ``stage``-prefixed keys.

        Args:
            batch: Dict with 'pixel_values' and 'label'.
            stage: Key prefix; logs '<stage>_loss' (optional) and '<stage>_acc'.
            metric: Accuracy metric object to update and log.
            with_loss: When False, cross-entropy is neither computed nor logged.

        Returns:
            The cross-entropy loss, or None when ``with_loss`` is False.
        """
        labels = batch['label']
        logits = self(batch['pixel_values'])
        loss = torch.nn.functional.cross_entropy(logits,labels) if with_loss else None
        metric(logits.argmax(dim=-1),labels)
        if loss is not None:
            self.log(f'{stage}_loss',loss,**self._LOG_KWARGS)
        self.log(f'{stage}_acc',metric,**self._LOG_KWARGS)
        return loss

    def training_step(self,batch,batch_idx):
        """Log 'training_loss' / 'training_acc' and return the loss for backprop."""
        return self._run_step(batch,'training',self.training_acc)

    def validation_step(self,batch,batch_idx):
        """Log 'validation_loss' / 'validation_acc'; returns nothing."""
        self._run_step(batch,'validation',self.validation_acc)

    def test_step(self,batch,batch_idx):
        """Log 'test_acc' only; no loss is computed for the test stage."""
        self._run_step(batch,'test',self.test_acc,with_loss=False)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters at ``self.lr``."""
        return torch.optim.AdamW(self.parameters(),lr=self.lr)

In [None]:
# Wrap the timm model for Lightning and configure a single-GPU,
# mixed-precision trainer.
numepochs = 3
l_model = Swin_Lightning(model=model, lr=3e-4)

trainer_options = dict(
    accelerator='gpu',
    devices=1,
    max_epochs=numepochs,
    precision='16-mixed',
)
trainer = L.Trainer(**trainer_options)

In [None]:
# Train on the training loader, validating on the validation loader.
trainer.fit(l_model, train_dataloaders=train_dl, val_dataloaders=val_dl)

In [None]:
# Evaluate on the held-out test split (test_dl), not the validation loader
# that was already used during fitting.
trainer.test(dataloaders=test_dl)

In [None]:
# Authenticate with the Hugging Face Hub before the upload step.
from huggingface_hub import login
login()

In [None]:
# Upload the trained model together with its class names to the Hugging Face Hub.
# Calls the public timm.models export instead of reaching into the private
# `_hub` module, whose location is an implementation detail.
# NOTE(review): newer timm releases may prefer a 'label_names' config key
# over 'labels' — confirm against the installed version.
timm.models.push_to_hf_hub(
    model,
    'swin_s3_base_224-Foods-101',
    model_config={'labels': list(id2label.values())},
)