In [None]:
!pip install natsort
!pip install wandb 

In [None]:
!cp /kaggle/input/model-scripts/dataset.py /kaggle/working/
!cp /kaggle/input/model-scripts/model.py /kaggle/working/

In [1]:
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
from dataset import HAM10000, preload_ham10000
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import load_model
import torch.nn.functional as F
import wandb
from natsort import natsorted
from matplotlib import pyplot as plt
import numpy as np

In [2]:
import random

# Seed every RNG before anything stochastic runs. This covers DataLoader
# shuffling and weight initialisation, and also the train/val split if
# preload_ham10000 draws from one of these generators (TODO confirm).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

# ImageNet per-channel statistics; presumably what the pretrained backbone
# returned by load_model expects (confirm against model.py).
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset_root = '/kaggle/input/ham1000-segmentation-and-classification'
train_data, val_data = preload_ham10000(dataset_root, val_size=0.2)

In [3]:
# Start the experiment-tracking run; the config values below are read back via
# wandb.config when building loaders, the optimizer and the epoch count.
wandb.init(
    project="ham10000-classification",
    config=dict(
        learning_rate=1e-4,
        epochs=10,
        batch_size=8,
        optimizer="Adam",
    ),
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhassantamerha[0m ([33mhassantamerha-alexandria-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Work on a subset so each epoch stays short; set a size to None to use the
# full split (seq[:None] keeps every element).
# NOTE(review): this keeps the *first* N items. If preload_ham10000 returns data
# in a sorted order, the subset is not a random sample — confirm.
TRAIN_SUBSET_SIZE = 1200
VAL_SUBSET_SIZE = 300
train_data = train_data[:TRAIN_SUBSET_SIZE]
val_data = val_data[:VAL_SUBSET_SIZE]

In [5]:
train_dataset = HAM10000(data=train_data, transform=preprocess)
val_dataset = HAM10000(data=val_data, transform=preprocess)

# Shuffle only the training set.
train_loader = DataLoader(train_dataset, batch_size=wandb.config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=wandb.config.batch_size, shuffle=False)

# len(loader) counts batches, not samples, so report both explicitly.
print(f"Train: {len(train_dataset)} samples, {len(train_loader)} batches")
print(f"Val: {len(val_dataset)} samples, {len(val_loader)} batches")

Train size: 150
Test size: 38


In [6]:
# Prefer the GPU when one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = load_model()  # project helper imported from model.py
print(f"Using device: {device}")

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Total number of parameters in the model: 42,004,074
Using device: cuda


In [7]:
# Move the model to the target device *before* constructing the optimizer, as
# the torch.optim documentation recommends, so the optimizer is built over the
# on-device parameters.
model.to(device)

criterion = nn.CrossEntropyLoss()
criterion.to(device)  # no parameters today; keeps a future `weight=` tensor on the same device
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
num_epochs = wandb.config.epochs

In [8]:
# Bookkeeping for checkpointing and early stopping, consumed by the training loop.
checkpoint_path = "best_model.pth"  # file that receives the best-so-far weights
patience = 3                        # epochs without improvement before stopping
best_val_acc, epochs_without_improvement = 0.0, 0

In [None]:
import os


def _match_mask_size(logits, masks):
    """Resize 4-D logits to the spatial size of 3-D `masks` when the two differ.

    The train and validation passes both use this, so the loss always compares
    tensors of matching spatial size. Bilinear resizing at the same size is an
    identity, so it is skipped when the shapes already match.
    """
    if logits.shape[-2:] != masks.shape[-2:]:
        logits = F.interpolate(logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)
    return logits


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for inputs, masks in tqdm(loader, desc="train"):
        inputs, masks = inputs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = _match_mask_size(model(inputs)['out'], masks)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # max(..., 1) guards against ZeroDivisionError on an empty loader.
    return running_loss / max(len(loader), 1)


def _validate(model, loader, criterion, device):
    """Evaluate on `loader`; return (mean per-batch loss, pixel accuracy in percent)."""
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, masks in tqdm(loader, desc="val"):
            inputs, masks = inputs.to(device), masks.to(device)
            outputs = _match_mask_size(model(inputs)['out'], masks)
            val_loss += criterion(outputs, masks).item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == masks).sum().item()
            total += masks.numel()
    mean_loss = val_loss / max(len(loader), 1)
    pixel_acc = correct / total * 100 if total else 0.0
    return mean_loss, pixel_acc


try:
    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}, loss: {train_loss}")

        val_loss, val_acc = _validate(model, val_loader, criterion, device)
        print(f"Epoch {epoch + 1}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        # Every wandb.log call advances the step counter by default. Logging once
        # per epoch keeps train and val metrics on the same step.
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
        })

        # NOTE(review): background pixels dominate pixel accuracy; val_loss or
        # Dice may be a better model-selection criterion.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
            }, checkpoint_path)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epoch(s).")

        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break

    print("Training completed.")
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

# The loop leaves the *last* epoch's weights in `model`. Restore the best
# checkpoint so that later evaluation scores the selected model.
if os.path.exists(checkpoint_path):
    best_checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])
    print(f"Restored best model from epoch {best_checkpoint['epoch']} "
          f"(val accuracy {best_checkpoint['best_val_acc']:.2f}%).")



100%|██████████| 150/150 [03:37<00:00,  1.45s/it]


Epoch 1, loss: 0.23754760786890983


100%|██████████| 38/38 [00:17<00:00,  2.18it/s]


Epoch 1, Val Loss: 0.1456, Val Accuracy: 95.17%


100%|██████████| 150/150 [03:37<00:00,  1.45s/it]


Epoch 2, loss: 0.12252996663252512


100%|██████████| 38/38 [00:17<00:00,  2.18it/s]


Epoch 2, Val Loss: 0.1235, Val Accuracy: 95.25%


100%|██████████| 150/150 [03:37<00:00,  1.45s/it]


Epoch 3, loss: 0.10154160452385744


100%|██████████| 38/38 [00:17<00:00,  2.18it/s]


Epoch 3, Val Loss: 0.1100, Val Accuracy: 95.75%


100%|██████████| 150/150 [03:37<00:00,  1.45s/it]


Epoch 4, loss: 0.08182788168390592


100%|██████████| 38/38 [00:17<00:00,  2.18it/s]


Epoch 4, Val Loss: 0.1075, Val Accuracy: 95.78%


100%|██████████| 150/150 [03:37<00:00,  1.45s/it]


Epoch 5, loss: 0.061101084326704344


100%|██████████| 38/38 [00:17<00:00,  2.18it/s]


Epoch 5, Val Loss: 0.1035, Val Accuracy: 96.02%


 75%|███████▌  | 113/150 [02:43<00:53,  1.45s/it]

In [None]:
def dice_score(pred, target, positive_class=None, eps=1e-6):
    """Dice coefficient between a predicted and a ground-truth label map.

    Args:
        pred: Tensor of predicted labels (any shape; it is flattened).
        target: Tensor of ground-truth labels with as many elements as `pred`.
        positive_class: If None (default), `pred` and `target` are used as-is.
            Use this only for binary 0/1 maps. If given, both tensors are first
            binarised to `== positive_class`, so a multi-class label map can be
            scored one class at a time.
        eps: Smoothing term added to numerator and denominator. It avoids 0/0
            and makes two empty masks score 1.0.

    Returns:
        The Dice score as a Python float.
    """
    pred = pred.flatten()
    target = target.flatten()
    if positive_class is not None:
        pred = (pred == positive_class).float()
        target = (target == positive_class).float()

    intersection = (pred * target).sum()
    dice = (2. * intersection + eps) / (pred.sum() + target.sum() + eps)
    return dice.item()

In [None]:
def display_segmentation_examples(inputs, masks, outputs, num_examples=3,
                                  mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot input image, ground-truth mask and predicted mask side by side.

    Args:
        inputs: Batch of normalised images; each sample is channel-first (C, H, W).
            Assumes the same mean/std normalisation as `preprocess`; pass
            `mean`/`std` otherwise.
        masks: Batch of ground-truth label maps, one 2-D map per sample.
        outputs: Batch of per-class scores; argmax over dim 0 of each sample
            gives its predicted label map.
        num_examples: Number of samples to plot. It is clamped to the batch
            size, and nothing is drawn if it is < 1.
        mean, std: Per-channel statistics used to undo the input normalisation,
            so the image is shown with its real colours.
    """
    num_examples = min(num_examples, len(inputs))
    if num_examples < 1:
        return

    # squeeze=False keeps `axes` 2-D even when num_examples == 1. With the
    # default squeeze=True a single row comes back 1-D and `axes[i][0]` fails.
    fig, axes = plt.subplots(num_examples, 3, figsize=(12, 4 * num_examples), squeeze=False)
    channel_mean = np.asarray(mean, dtype=np.float32).reshape(1, 1, -1)
    channel_std = np.asarray(std, dtype=np.float32).reshape(1, 1, -1)

    for i in range(num_examples):
        ax = axes[i]

        # Original image: CHW -> HWC, then undo Normalize. imshow expects floats
        # in [0, 1], and normalised values fall outside that range.
        image = inputs[i].detach().cpu().numpy().transpose(1, 2, 0)
        image = np.clip(image * channel_std + channel_mean, 0.0, 1.0)
        ax[0].imshow(image)
        ax[0].set_title("Input Image")
        ax[0].axis('off')

        # Ground truth mask
        ax[1].imshow(masks[i].cpu().numpy(), cmap='gray')
        ax[1].set_title("Ground Truth")
        ax[1].axis('off')

        # Predicted mask: per-pixel class with the highest score.
        pred_mask = torch.argmax(outputs[i], dim=0)
        ax[2].imshow(pred_mask.cpu().numpy(), cmap='gray')
        ax[2].set_title("Prediction")
        ax[2].axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure when called repeatedly

In [None]:
def evaluate_model(val_loader, model, device, max_batches=21, num_display_examples=3):
    """Compute the mean per-batch Dice score and plot examples from the first batch.

    Args:
        val_loader: Iterable yielding (inputs, masks) batches.
        model: Segmentation model whose forward output is indexed with 'out'
            to get per-class logits.
        device: Device the batches are moved to.
        max_batches: Number of batches to score. The default of 21 matches the
            previous hard-coded `i > 20` cut-off; None scores the whole loader.
        num_display_examples: Samples from the first batch to plot; 0 disables plotting.

    Returns:
        The average Dice score as a float, or nan if no batch was scored.
    """
    model.eval()  # Set the model to evaluation mode
    dice_scores = []

    with torch.no_grad():
        for batch_idx, (inputs, masks) in enumerate(val_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break

            inputs, masks = inputs.to(device), masks.to(device)
            outputs = model(inputs)['out']

            # Resize logits to the mask resolution (an identity when sizes already match).
            outputs = torch.nn.functional.interpolate(outputs, size=masks.shape[1:], mode='bilinear', align_corners=False)

            pred_mask = torch.argmax(outputs, dim=1)  # per-pixel predicted class
            dice_scores.append(dice_score(pred_mask, masks))

            # Plot examples from the first batch only, rather than one figure per batch.
            if batch_idx == 0 and num_display_examples > 0:
                display_segmentation_examples(
                    inputs, masks, outputs,
                    num_examples=min(num_display_examples, len(inputs)),
                )

    if not dice_scores:
        print("No batches evaluated; Dice score is undefined.")
        return float('nan')

    avg_dice_score = float(np.mean(dice_scores))
    print(f"Average Dice Score: {avg_dice_score:.4f}")
    return avg_dice_score

In [None]:
evaluate_model(val_loader, model, device)

from IPython.display import FileLink
# Reuse the checkpoint path configured for training instead of repeating the literal.
FileLink(checkpoint_path)

In [46]:
from IPython.display import FileLink

# Download link for the saved best checkpoint (path set in the early-stopping config).
checkpoint_link = FileLink(checkpoint_path)
checkpoint_link

In [None]:
# To download the checkpoint, render a link with IPython.display.FileLink(checkpoint_path).
# Do not paste Kaggle proxy URLs here. They are not valid Python (the cell fails on
# Run All), they expire with the session, and they embed a signed access token.