<a href="https://colab.research.google.com/github/LekiWangmo/AiDermitology/blob/main/Untitled29.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
!pip install -q torch_snippets pytorch_model_summary

In [None]:
from torch_snippets import *
import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split
# Pick the compute device: use CUDA when PyTorch reports a usable GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
# Tensor conversion followed by per-channel normalisation.
# (The mean/std values match the widely used ImageNet channel statistics.)
tfms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import datasets, transforms

# Preprocessing pipeline, applied in order: convert the input to a tensor,
# then resize it to a fixed 224x224 grid so samples can be stacked in a batch.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224)),
    ]
)

class VOCSegData(Dataset):
    """PASCAL VOC 2012 segmentation samples as ``(image, mask)`` pairs.

    The image is passed through ``transform`` (when given). The mask is NOT
    passed through the image transform: it holds integer class ids, so it is
    read as raw labels and, if needed, resized with nearest-neighbour
    sampling so that no label value is rescaled or blended.

    Args:
        root_dir: Directory where torchvision stores/downloads the dataset.
        split: Image set name forwarded to ``VOCSegmentation`` (e.g. 'train',
            'val').
        transform: Optional callable applied to the image only.

    Each item is ``(image, mask)`` where ``mask`` is a 2-D ``torch.long``
    tensor of class ids. VOC masks use 0-20 for classes and 255 for
    void/boundary pixels; those values are preserved as-is.
    """

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        # Load VOC dataset from torchvision
        self.voc_dataset = datasets.VOCSegmentation(root=root_dir, year='2012', image_set=split, download=True)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.voc_dataset)

    def __getitem__(self, idx):
        """Return the transformed image and its integer label mask."""
        image, mask = self.voc_dataset[idx]

        if self.transform:
            image = self.transform(image)

        # Read the mask as raw integer labels. Running it through an image
        # transform such as ToTensor() would rescale the ids to [0, 1] and a
        # later .long() would then truncate almost every label to 0.
        mask = torch.as_tensor(np.array(mask), dtype=torch.long)

        # Keep the mask aligned with the (possibly resized) image. Nearest
        # sampling is required: any other interpolation would invent class
        # ids that lie "between" two neighbouring labels.
        if torch.is_tensor(image):
            target_size = tuple(image.shape[-2:])
            if tuple(mask.shape[-2:]) != target_size:
                resized = torch.nn.functional.interpolate(
                    mask[None, None].float(), size=target_size, mode='nearest'
                )
                mask = resized[0, 0].long()

        return image, mask

def collate_fn(batch):
    """Combine a list of ``(image, mask)`` samples into two batched tensors.

    Args:
        batch: Sequence of ``(image, mask)`` tensor pairs; all images must
            share one shape and all masks must share one shape.

    Returns:
        ``(images, masks)``, each stacked along a new leading batch dimension.
    """
    # Transpose [(img, msk), ...] into (img, img, ...) and (msk, msk, ...).
    image_seq, mask_seq = zip(*batch)
    return torch.stack(image_seq), torch.stack(mask_seq)



In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

# Build the training split and fetch the 4th sample (index 3).
root_dir = './data'  # directory where the dataset is stored/downloaded
trn_ds = VOCSegData(root_dir=root_dir, split='train', transform=transform)
image, mask = trn_ds[3]

# Convert both tensors to PIL images so matplotlib can display them.
image_np = TF.to_pil_image(image)
mask_np = TF.to_pil_image(mask.to(torch.uint8))

# Two-column layout; only the left panel (the image) is currently drawn.
fig = plt.figure(figsize=(10, 5))

image_ax = fig.add_subplot(1, 2, 1)
image_ax.set_title("Image")
image_ax.imshow(image_np)
image_ax.axis('off')

# Mask panel is disabled; re-enable to show the segmentation mask as well.
# mask_ax = fig.add_subplot(1, 2, 2)
# mask_ax.set_title("Segmentation Mask")
# mask_ax.imshow(mask_np, cmap='gray')
# mask_ax.axis('off')

fig.tight_layout()
plt.show()


In [None]:
from torch.utils.data import DataLoader

# Build the train and validation datasets from the same root directory and
# with the same preprocessing; only the split name differs.
trn_ds, val_ds = (
    VOCSegData(root_dir='./data', split=split_name, transform=transform)
    for split_name in ('train', 'val')
)

# Training batches are shuffled and hold 4 samples; validation runs one
# sample at a time in a fixed order.
trn_dl = DataLoader(
    trn_ds,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
val_dl = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)


100%|██████████| 2.00G/2.00G [01:55<00:00, 17.3MB/s]


In [None]:
import torch.nn as nn

def conv(in_channels, out_channels):
    """Build a 3x3 convolution block: Conv2d -> BatchNorm2d -> ReLU.

    The convolution uses stride 1 and padding 1, so height and width are
    unchanged; only the channel count goes from ``in_channels`` to
    ``out_channels``.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def up_conv(in_channels, out_channels):
    """Build an upsampling block: ConvTranspose2d -> ReLU.

    The transposed convolution uses a 2x2 kernel with stride 2, which doubles
    the height and width while mapping ``in_channels`` to ``out_channels``.

    Returns:
        An ``nn.Sequential`` holding the two layers in that order.
    """
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


In [None]:
import torch
import torch.nn as nn
from torchvision.models import vgg16_bn

class UNet(nn.Module):
    """U-Net style segmentation network using VGG16-BN features as encoder.

    The VGG feature stack is cut into five encoder stages plus a bottleneck.
    Each decoder step upsamples, concatenates the matching encoder output
    along the channel axis, and fuses the result with a ``conv`` block.

    Args:
        pretrained: Forwarded to ``vgg16_bn`` to request pretrained weights.
        out_channels: Number of channels produced by the final 1x1 conv
            (one per predicted class).
    """

    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()

        self.encoder = vgg16_bn(pretrained=pretrained).features

        # Encoder stages: consecutive slices of the VGG feature layers.
        self.block1 = nn.Sequential(*self.encoder[0:6])
        self.block2 = nn.Sequential(*self.encoder[6:13])
        self.block3 = nn.Sequential(*self.encoder[13:20])
        self.block4 = nn.Sequential(*self.encoder[20:27])
        self.block5 = nn.Sequential(*self.encoder[27:34])

        # Everything left in the VGG stack, followed by a custom widening conv.
        self.bottleneck = nn.Sequential(*self.encoder[34:])
        self.conv_bottleneck = conv(512, 1024)

        # Decoder: each fusing conv takes (upsampled channels + skip channels).
        self.up_conv6 = up_conv(1024, 512)
        self.conv6 = conv(512 + 512, 512)

        self.up_conv7 = up_conv(512, 256)
        self.conv7 = conv(256 + 512, 256)

        self.up_conv8 = up_conv(256, 128)
        self.conv8 = conv(128 + 256, 128)

        self.up_conv9 = up_conv(128, 64)
        self.conv9 = conv(64 + 128, 64)

        self.up_conv10 = up_conv(64, 32)
        self.conv10 = conv(32 + 64, 32)

        # 1x1 conv mapping the last decoder features to per-class scores.
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class scores for the input batch ``x``."""
        # Encoder: run the stages in order, keeping every output as a skip.
        skips = []
        encoder_stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for stage in encoder_stages:
            x = stage(x)
            skips.append(x)

        # Bottleneck
        x = self.conv_bottleneck(self.bottleneck(x))

        # Decoder: deepest skip first (block5 ... block1).
        decoder_steps = (
            (self.up_conv6, self.conv6),
            (self.up_conv7, self.conv7),
            (self.up_conv8, self.conv8),
            (self.up_conv9, self.conv9),
            (self.up_conv10, self.conv10),
        )
        for (upsample, fuse), skip in zip(decoder_steps, reversed(skips)):
            x = fuse(torch.cat([upsample(x), skip], dim=1))

        return self.conv11(x)


In [None]:
import torch
import torch.nn as nn

# Cross-Entropy Loss for multi-class segmentation
# ce = nn.CrossEntropyLoss()

import torch.nn.functional as F

# PASCAL VOC masks mark void/boundary pixels with 255. Those pixels carry no
# class, so they must be excluded from both the loss and the accuracy.
VOID_LABEL = 255

# CrossEntropyLoss (already handles log-softmax + NLL loss internally)
ce = nn.CrossEntropyLoss(ignore_index=VOID_LABEL)


def UnetLoss(preds, targets):
    """Compute cross-entropy loss and pixel accuracy for a segmentation batch.

    Args:
        preds: Raw class scores of shape (N, C, H, W).
        targets: Integer class ids of shape (N, H, W); pixels equal to
            ``VOID_LABEL`` are ignored.

    Returns:
        ``(ce_loss, acc)`` as 0-dim tensors. ``acc`` is the fraction of
        non-void pixels whose arg-max prediction equals the target. It is 0
        if every pixel is void; the loss is then not meaningful.
    """
    targets = targets.long()  # CrossEntropyLoss requires LongTensor class ids
    ce_loss = ce(preds, targets)

    valid = targets != VOID_LABEL
    correct = (torch.argmax(preds, dim=1) == targets) & valid
    # clamp avoids 0/0 when a batch contains only void pixels
    acc = correct.float().sum() / valid.float().sum().clamp(min=1)
    return ce_loss, acc



In [None]:
# def train_batch(model, data, optimizer, criterion):
#     model.train()  # Set the model to training mode
#     ims, ce_masks = data  # Unpack the batch into input images and ground truth masks
#     _masks = model(ims)   # Forward pass to get predictions
#     optimizer.zero_grad()  # Clear previous gradients
#     loss, acc = criterion(_masks, ce_masks)  # Compute loss and accuracy
#     loss.backward()  # Backpropagate the loss
#     optimizer.step()  # Update model parameters
#     return loss.item(), acc.item()  # Return scalar values for tracking

# def train_batch(model, data, optimizer, criterion):
#     model.train()
#     ims, ce_masks = data
#     ims, ce_masks = ims.to(device), ce_masks.to(device)  # Move to same device as model

#     _masks = model(ims)
#     optimizer.zero_grad()
#     loss, acc = criterion(_masks, ce_masks)
#     loss.backward()
#     optimizer.step()

def train_batch(model, data, optimizer, criterion):
    """Run one optimisation step on a single batch.

    Args:
        model: Network to train; switched to training mode.
        data: ``(images, masks)`` batch as produced by the data loader.
        optimizer: Optimizer that owns ``model``'s parameters.
        criterion: Callable ``(preds, masks) -> (loss, acc)`` returning
            tensors.

    Returns:
        ``(loss, acc)`` as Python floats for logging.
    """
    model.train()
    ims, ce_masks = data

    # The loader yields CPU tensors while the model may live on the GPU;
    # move the batch to wherever the model's parameters are.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        ims = ims.to(first_param.device)
        ce_masks = ce_masks.to(first_param.device)

    optimizer.zero_grad()  # clear gradients left over from the previous step
    _masks = model(ims)
    loss, acc = criterion(_masks, ce_masks)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()



In [None]:
# @torch.no_grad()
# def validate_batch(model, data, criterion):
#     model.eval()
#     ims, masks = data
#     ims, masks = ims.to(device), masks.to(device)

#     outputs = model(ims)
#     loss = criterion(outputs, masks)

#     preds = torch.argmax(outputs, dim=1)
#     acc = (preds == masks).float().mean()  # Works if dims match

#     return loss.item(), acc.item()


# @torch.no_grad()
# def validate_batch(model, data, criterion):
#     model.eval()
#     ims, masks = data
#     ims, masks = ims.to(device), masks.to(device)

#     outputs = model(ims)
#     loss, acc = criterion(outputs, masks)  # Unpack both here
#     return loss.item(), acc.item()

@torch.no_grad()  # no gradients are needed while evaluating
def validate_batch(model, data, criterion):
    """Evaluate the model on a single batch without updating it.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        data: ``(images, masks)`` batch as produced by the data loader.
        criterion: Callable ``(preds, masks) -> (loss, acc)`` returning
            tensors.

    Returns:
        ``(loss, acc)`` as Python floats for logging.
    """
    model.eval()
    ims, masks = data

    # The loader yields CPU tensors while the model may live on the GPU;
    # move the batch to wherever the model's parameters are.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        ims = ims.to(first_param.device)
        masks = masks.to(first_param.device)

    _masks = model(ims)
    loss, acc = criterion(_masks, masks)
    return loss.item(), acc.item()



In [None]:
# PASCAL VOC segmentation has 20 object classes plus background (labels 0-20),
# so the network needs 21 output channels rather than UNet's default of 12.
model = UNet(out_channels=21).to(device)
criterion = UnetLoss  # returns (loss, accuracy) for a batch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
n_epochs = 20


Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|██████████| 528M/528M [00:03<00:00, 154MB/s]


In [None]:
# class Report:
#     def __init__(self, n_epochs):
#         self.n_epochs = n_epochs
#         self.epoch_data = []

#     def record(self, step, **metrics):
#         self.epoch_data.append((step, metrics))  # No printing per batch

#     def report_avgs(self, epoch):
#         trn_loss, trn_acc, val_loss, val_acc = [], [], [], []
#         for _, data in self.epoch_data:
#             if 'trn_loss' in data: trn_loss.append(data['trn_loss'])
#             if 'trn_acc' in data: trn_acc.append(data['trn_acc'])
#             if 'val_loss' in data: val_loss.append(data['val_loss'])
#             if 'val_acc' in data: val_acc.append(data['val_acc'])

#         print(f"Epoch {epoch}/{self.n_epochs}:")
#         if trn_loss:
#             print(f"  Train Loss: {sum(trn_loss)/len(trn_loss):.4f}, Accuracy: {sum(trn_acc)/len(trn_acc):.4f}")
#         if val_loss:
#             print(f"  Val   Loss: {sum(val_loss)/len(val_loss):.4f}, Accuracy: {sum(val_acc)/len(val_acc):.4f}")

#         self.epoch_data = []  # Reset for next epoch


In [None]:
# Report collects the per-batch metrics passed to record() and prints their
# averages when report_avgs() is called at the end of each epoch.
# NOTE(review): Report is presumably provided by the torch_snippets star
# import - confirm.
log = Report(n_epochs)

for epoch in range(n_epochs):
    # Training pass: one optimisation step per batch.
    N = len(trn_dl)
    for step, data in enumerate(trn_dl, start=1):
        loss, acc = train_batch(model, data, optimizer, criterion)
        # Position is the fractional epoch: epoch index + share of batches done.
        log.record(epoch + step / N, trn_loss=loss, trn_acc=acc, end='\r')

    # Validation pass: same bookkeeping, no parameter updates.
    N = len(val_dl)
    for step, data in enumerate(val_dl, start=1):
        loss, acc = validate_batch(model, data, criterion)
        log.record(epoch + step / N, val_loss=loss, val_acc=acc, end='\r')

    # Print the averaged metrics for the epoch that just finished (1-based).
    log.report_avgs(epoch + 1)

In [None]:
# Plot train vs. validation curves: first the losses, then the accuracies.
for metric_pair in (['trn_loss', 'val_loss'], ['trn_acc', 'val_acc']):
    log.plot_epochs(metric_pair)

AttributeError: 'Report' object has no attribute 'plot_epochs'