In [1]:
from torch import optim, nn, utils, Tensor
from torchvision.transforms import transforms
from torchvision.models import vision_transformer
import pytorch_lightning as PL
from pytorch_lightning.loggers import WandbLogger
import os
import kagglehub

In [2]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb-api")

In [3]:
# Login to wandb
import wandb
wandb.login(key=wandb_api, relogin=True)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [4]:
import torch 
import torchvision 
import torch.utils.data as data
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image 

class FruitsAndVeggies(Dataset):
    def __init__(self, split_root, transforms):
        # create a dict of labels and filepaths
        class_dir_names = sorted(os.listdir(split_root))
        self.n_classes = len(class_dir_names)
        self.transforms = transforms
        
        # create one-hot encoding 
        self.dataset_list = []
        for i, class_dir in enumerate(class_dir_names):
            label = torch.zeros(self.n_classes)
            label[i] = 1
            
            extension_set = {"jpg", "png",  "JPG", "jpeg"}

            for image in sorted(os.listdir(os.path.join(split_root, class_dir))):
                extension = image.split(".")[-1]

                if extension in extension_set:
                    self.dataset_list.append([label, os.path.join(split_root,class_dir,image)])
                else:
                    print(f"{extension} found in dataset")
        
    def __len__(self):
        return len(self.dataset_list)

    def __getitem__(self, index):
        data_list = self.dataset_list[index]
        image_path = data_list[1]
        label = data_list[0]
        image = Image.open(image_path)
       

        if image.mode != "RGB":
            image = image.convert("RGB")

        image = self.transforms(image)

        return image, label


In [5]:
# Download latest version
data_path = kagglehub.dataset_download("kritikseth/fruit-and-vegetable-image-recognition")

print("Path to dataset files:", data_path)

LEARNING_RATE = 1e-4*0.70
BATCH_SIZE = 32
ARCHITECTURE = "ViT_B_16"
# ViT_B_16_Weights.IMAGENET1K_V1
# ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
# ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1
# ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
vit_model = vision_transformer.vit_b_16(weights=vision_transformer.ViT_B_16_Weights.IMAGENET1K_V1)
model_transforms = vision_transformer.ViT_B_16_Weights.IMAGENET1K_V1.transforms()

train_path = os.path.join(data_path, "train")
validation_path = os.path.join(data_path, "validation")
test_path = os.path.join(data_path, "test")

train_dataset = FruitsAndVeggies(train_path, model_transforms)
validation_dataset = FruitsAndVeggies(validation_path, model_transforms)
test_dataset = FruitsAndVeggies(test_path, model_transforms)

train_dataloader = utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
validation_dataloader = utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_dataloader = utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

Path to dataset files: /kaggle/input/fruit-and-vegetable-image-recognition


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 200MB/s] 


In [6]:
class ViT_lightning(PL.LightningModule):
    def __init__(self, model, num_classes, lr):
        super().__init__()
        
        self.model = model
        self.model.heads.head = nn.Linear(self.model.heads.head.in_features, num_classes)
        
        self.num_classes = num_classes
        #self.loss_fn = nn.CrossEntropyLoss()
        self.loss_fn = nn.functional.cross_entropy
        self.lr = lr

    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        #optimizer = optim.Adam(self.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=0.1)
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=0.01)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        preds = self(x)

        # Convert one-hot labels to class indices
        if y.ndim == 2 and y.shape[1] > 1:  # One-hot encoded case
            y = y.argmax(dim=1)  # Convert [batch_size, num_classes] → [batch_size]

        loss = self.loss_fn(preds, y)
        acc = (preds.argmax(dim=1) == y).float().mean()
        
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_acc", acc, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        preds = self(x)

        # Convert one-hot labels to class indices
        if y.ndim == 2 and y.shape[1] > 1:  # One-hot encoded case
            y = y.argmax(dim=1)  # Convert [batch_size, num_classes] → [batch_size]

        loss = self.loss_fn(preds, y)
        acc = (preds.argmax(dim=1) == y).float().mean()
        
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc", acc, prog_bar=True, sync_dist=True)

    def _step(self, train_batch, batch_idx):
        x, y = train_batch
        preds = self(x)

        # Convert one-hot labels to class indices
        if y.ndim == 2 and y.shape[1] > 1:  # One-hot encoded case
            y = y.argmax(dim=1)  # Convert [batch_size, num_classes] → [batch_size]

        loss = self.loss_fn(preds, y)
        acc = (preds.argmax(dim=1) == y).float().mean()
        
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_acc", acc, prog_bar=True, sync_dist=True)
        return loss

In [7]:
# Model creation
model = ViT_lightning(
    model=vit_model,
    num_classes=train_dataset.n_classes,
    lr=LEARNING_RATE,
)

In [8]:
wandb_logger = WandbLogger(log_model="all", name=f"{ARCHITECTURE}-dreez", entity="avs-846", project="Mini-project")
trainer = PL.Trainer(max_epochs=5, 
                     accelerator="gpu", 
                     devices=2, 
                     precision=32, 
                     log_every_n_steps=10,
                     logger=wandb_logger,
                    )

wandb_logger.experiment.config.update({
    "batch_size": BATCH_SIZE,
    "num_classes": train_dataset.n_classes,
    "learning_rate": LEARNING_RATE,
    "optimizer": "AdamW",
    "architecture": ARCHITECTURE,
    "dataset": "fruit-and-vegetable-image-recognition",
})

trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloader)
#trainer.test(model=model, dataloaders=test_dataloader)
#lightning_fabric.utilities.exceptions.MisconfigurationException: No `test_step()` method defined to run `Trainer.test`.

wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mandreas-hovaldt[0m ([33mavs-846[0m). Use [1m`wandb login --relogin`[0m to force relogin




0,1
epoch,▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆██████
train_acc,▁▆▇▆▇▆▇▇▇▇██▇▇███▇██████
train_loss,█▅▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇███
val_acc,▁▅▇▇█
val_loss,█▃▁▁▁

0,1
epoch,4.0
train_acc,0.98438
train_loss,0.06869
trainer/global_step,244.0
val_acc,0.96875
val_loss,0.14216


In [9]:
wandb.finish(quiet=False)