In [None]:
import os
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
from sklearn.model_selection import train_test_split
import timm

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Training configuration — everything a reader might want to tune.
num_classes = 4          # presumably one per class folder (Normal/Cyst/Tumor/Stone per the dataset name) — confirm via full_dataset.classes
batch_size = 32
num_epochs = 25
learning_rate = 1e-4

# Seed torch's RNG so random augmentation, DataLoader shuffling and any weight
# initialisation are repeatable across "Restart & Run All". (The train/val split
# is seeded separately via train_test_split's random_state.)
# NOTE: GPU runs may still not be bit-exact because some cuDNN kernels are nondeterministic.
seed = 42
torch.manual_seed(seed)

In [None]:
# Root folder laid out as one sub-directory per class (ImageFolder convention).
DATA_DIR = "/kaggle/input/ct-kidney-dataset-normal-cyst-tumor-and-stone/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone/CT-KIDNEY-DATASET-Normal-Cyst-Tumor-Stone"

# Untransformed view of the data: only used here to get the sample count and labels.
full_dataset = datasets.ImageFolder(DATA_DIR)

# Stratified 80/20 split over sample indices, so both halves keep the class ratios.
all_indices = np.arange(len(full_dataset))
train_idx, val_idx = train_test_split(
    all_indices,
    stratify=full_dataset.targets,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Standard ImageNet channel mean/std, presumably matching the preprocessing of the
# pretrained timm weights — confirm with timm.data.resolve_data_config(model=model).
# Normalize is stateless, so a single instance can be shared by both pipelines.
_imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random-scale/aspect crop to 224 plus a random horizontal flip.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _imagenet_normalize,
    ]
)

# Evaluation pipeline: deterministic resize to 256, then a 224x224 center crop.
test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _imagenet_normalize,
    ]
)

In [None]:
# Two ImageFolder views of the same directory that differ only in their transform,
# so training indices get augmentation and validation indices do not. The index
# lists from the split line up because ImageFolder enumerates the directory in
# the same sorted order every time.
train_dataset = Subset(
    datasets.ImageFolder(DATA_DIR, transform=train_transforms),
    train_idx,
)
val_dataset = Subset(
    datasets.ImageFolder(DATA_DIR, transform=test_transforms),
    val_idx,
)

In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,   # reshuffle the training samples on every pass
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,  # fixed order for evaluation
    num_workers=4,
)

In [None]:
# EfficientViT-M2 from timm with its pretrained weights; num_classes sizes the
# classifier head for this task. Module.to() returns the module itself.
model = timm.create_model(
    "efficientvit_m2",
    pretrained=True,
    num_classes=num_classes,
).to(device)

In [None]:
# Cross-entropy on raw logits, Adam over every model parameter, and a step
# schedule that multiplies the LR by 0.1 after every 7 scheduler.step() calls.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
best_acc: float = 0.0  # best evaluation accuracy so far; a checkpoint is written whenever it is beaten

In [None]:
# Train for `num_epochs`, evaluating on the validation split after every epoch.
# The best-scoring weights are checkpointed, and the StepLR schedule is advanced
# once per epoch so the learning rate is actually decayed during training.


def _run_epoch(loader, train, desc):
    """Run one full pass of ``model`` over ``loader``.

    Args:
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        train: if True, use train mode and update weights with ``optimizer``;
            otherwise use eval mode with gradient tracking disabled.
        desc: label shown on the tqdm progress bar.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, as Python floats
        (both 0.0 if the loader yields no batches).
    """
    model.train(train)
    total_loss = 0.0
    total_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(train):
        for inputs, labels in tqdm(loader, desc=desc, leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            n_batch = inputs.size(0)
            # criterion uses the default 'mean' reduction; re-weight by batch size so
            # the epoch mean is exact even when the last batch is smaller.
            total_loss += loss.item() * n_batch
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += n_batch
    denom = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
    return total_loss / denom, total_correct / denom


for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 30)

    epoch_loss, epoch_acc = _run_epoch(train_loader, train=True, desc="Training")
    print(f"Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}")

    val_epoch_loss, val_epoch_acc = _run_epoch(val_loader, train=False, desc="Validation")
    print(f"Val Loss: {val_epoch_loss:.4f} | Val Acc: {val_epoch_acc:.4f}")

    # Model selection on the validation split: keep the best weights on disk.
    if val_epoch_acc > best_acc:
        best_acc = val_epoch_acc
        torch.save(model.state_dict(), "efficientvit_m2_kidney_disease_classifier.pth")
        print("==> Best model saved!")

    # StepLR counts scheduler.step() calls: one per epoch => LR x0.1 every 7 epochs.
    scheduler.step()

In [None]:
# Evaluate the current model. No separate test split is built in this notebook
# (only train/val), so the validation loader/dataset are used here; referencing
# test_loader/test_dataset would raise NameError.
model.eval()
test_loss = 0.0
test_corrects = 0

with torch.no_grad():
    for inputs, labels in tqdm(val_loader, desc="Testing", leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        preds = outputs.argmax(dim=1)

        # Re-weight the per-batch mean loss by batch size for an exact overall mean.
        test_loss += loss.item() * inputs.size(0)
        test_corrects += (preds == labels).sum().item()

n_eval = max(len(val_dataset), 1)  # avoid ZeroDivisionError on an empty split
test_epoch_loss = test_loss / n_eval
test_epoch_acc = test_corrects / n_eval
print(f"Test Loss: {test_epoch_loss:.4f} | Test Acc: {test_epoch_acc:.4f}")

In [None]:
# Persist the weights if this evaluation beats the best accuracy recorded so far.
# NOTE(review): this cell sits after the training loop, so it runs once per
# "Run All" rather than once per epoch, and the scheduler.step() below happens
# after training has finished, so it cannot change the LR used for training.
checkpoint_path = "efficientvit_m2_kidney_disease_classifier.pth"
improved = test_epoch_acc > best_acc
if improved:
    best_acc = test_epoch_acc
    # state_dict() holds references to the live parameter tensors; the file
    # written here, not best_model_wts, is the frozen snapshot.
    best_model_wts = model.state_dict()
    torch.save(best_model_wts, checkpoint_path)
    print("==> Best model saved!")

scheduler.step()

In [None]:
# Reload the best checkpoint written during training and switch to inference mode.
# map_location places the tensors on the current device, so a checkpoint saved
# on a GPU can still be loaded in a CPU-only session (and vice versa).
model.load_state_dict(
    torch.load("efficientvit_m2_kidney_disease_classifier.pth", map_location=device)
)
model.eval();  # trailing ';' suppresses the very long model repr in the notebook

In [None]:
# Final accuracy of the reloaded model.
# NOTE: no separate test split exists, so this uses the validation split
# (test_loader/test_dataset are never defined). If checkpoint selection also used
# this split, the number is optimistically biased; hold out a third split for an
# unbiased estimate.
model.eval()  # make the cell safe to run on its own / out of order
test_corrects = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        preds = model(inputs).argmax(dim=1)
        test_corrects += (preds == labels).sum().item()

test_acc = test_corrects / max(len(val_dataset), 1)  # guard against an empty split
print(f"Final Test Accuracy: {test_acc:.4f}")