In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.models import resnet50, ResNet50_Weights, VGG16_Weights
from tqdm import tqdm



In [2]:
class MedicalImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

image_dataset = MedicalImageDataset('/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/train/imgs', transform=transform)

In [3]:
# Calculate lengths for the splits
train_size = int(0.95 * len(image_dataset))
val_size = len(image_dataset) - train_size

# Split the dataset
train_dataset, val_dataset = random_split(image_dataset, [train_size, val_size])

In [4]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [5]:
model_type = 'vit_h'
checkpoint = '/media/rohit/mirlproject2/fetal head circumference/sam_vit_h_4b8939.pth'
device = 'cuda'

sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)

sam_encoder = sam_model.image_encoder
sam_encoder.eval()

ImageEncoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1280, out_features=3840, bias=True)
        (proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=1280, out_features=5120, bias=True)
        (lin2): Linear(in_features=5120, out_features=1280, bias=True)
        (act): GELU(approximate='none')
      )
    )
    (1): Block(
      (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1280, out_features=3840, bias=True)
        (proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (norm2): LayerNorm((1280,), eps=1e-06, eleme

In [6]:
import torch
import torch.nn as nn
from torchvision import models

class StudentResNet(nn.Module):
    def __init__(self):
        super(StudentResNet, self).__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])  # Use layers except the last two
        
        # Adjust the channel size from 2048 to 256
        self.adjust_channels = nn.Conv2d(2048, 256, kernel_size=1)
        
        # Add upsampling layers to increase spatial dimensions to [64, 64]
        self.upsample = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # [14, 14]
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # [28, 28]
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),  # [56, 56]
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(size=(64, 64), mode='bilinear', align_corners=False)  # [64, 64]
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.adjust_channels(x)
        x = self.upsample(x)
        return x

student_model = StudentResNet().to(device)

In [7]:
import torch.nn.functional as F
from torchvision import models



class PerceptualLoss(nn.Module):
    def __init__(self, layers):
        super(PerceptualLoss, self).__init__()
        vgg = models.vgg16(weights=VGG16_Weights.DEFAULT).features
        self.layers = layers
        self.feature_extractor = nn.ModuleList()
        start_layer = 0
        for end_layer in self.layers:
            self.feature_extractor.append(nn.Sequential(*list(vgg.children())[start_layer:end_layer]))
            start_layer = end_layer
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Map of expected input channels at different layers
        self.expected_channels = [3, 64, 128, 256, 512]

    def forward(self, student_outputs, teacher_outputs):
        loss = 0

        for i, extractor in enumerate(self.feature_extractor):
            expected_channels = self.expected_channels[i]

            if student_outputs.size(1) != expected_channels:
                student_outputs = nn.Conv2d(student_outputs.size(1), expected_channels, kernel_size=1).to(student_outputs.device)(student_outputs)
                teacher_outputs = nn.Conv2d(teacher_outputs.size(1), expected_channels, kernel_size=1).to(teacher_outputs.device)(teacher_outputs)

            # Resize outputs to match the expected VGG input size at this layer
            target_size = (student_outputs.size(2), student_outputs.size(3))
            student_outputs_resized = F.interpolate(student_outputs, size=target_size, mode='bilinear', align_corners=False)
            teacher_outputs_resized = F.interpolate(teacher_outputs, size=target_size, mode='bilinear', align_corners=False)

            student_features = extractor(student_outputs_resized)
            teacher_features = extractor(teacher_outputs_resized)
            loss += F.mse_loss(student_features, teacher_features)
        return loss

In [8]:
class CombinedDistillationLoss(nn.Module):
    def __init__(self, teacher, layers):
        super(CombinedDistillationLoss, self).__init__()
        self.teacher = teacher
        self.criterion_mse = nn.MSELoss()
        self.criterion_perceptual = PerceptualLoss(layers)

    def forward(self, student_outputs, images):

        tensor_image = images.squeeze(0)
        image = tensor_image.permute(1, 2, 0).cpu().numpy()
        image = (image * 255).astype(np.uint8)

        transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        input_image = transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=device)
        #input_image_torch = torch.as_tensor(input_image)
        transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        input_image = sam_model.preprocess(transformed_image)
        with torch.no_grad():
            teacher_outputs = self.teacher(input_image).detach()

        mse_loss = self.criterion_mse(student_outputs, teacher_outputs)
        perceptual_loss = self.criterion_perceptual(student_outputs, teacher_outputs)
        loss = mse_loss + perceptual_loss
        return loss

# Instantiate the combined loss function
layers = [4, 9, 16, 23, 30]  # Corresponding to relu1_2, relu2_2, relu3_3, relu4_3, relu5_3 in VGG16
criterion = CombinedDistillationLoss(sam_encoder, layers).to(device)

In [9]:
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path  # Path to save the checkpoint
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        # Round the validation loss to 5 decimal places
        val_loss = round(val_loss, 4)
        score = -val_loss  # Convert to a maximization problem

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decreases."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.4f} -------> {val_loss:.4f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)  # Save the model state to the specified path
        self.val_loss_min = val_loss



In [10]:
training_loss = []
validation_loss = []

In [11]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np

# Ensure these variables are defined
training_loss = []
validation_loss = []

# Optimizer
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

# Early stopping with a specific path to save the best model
early_stopping = EarlyStopping(patience=10, verbose=True, path='/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/resnet50_early_stop.pth')

# Training loop
def train_model(student_model, teacher_model, train_loader, val_loader, criterion, optimizer, num_epochs=100):
    best_val_loss = float('inf')
    student_model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        
        # Training phase
        for images in train_loader_tqdm:
            images = images.to(device)

            optimizer.zero_grad()
            student_outputs = student_model(images)
            loss = criterion(student_outputs, images)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_loader_tqdm.set_postfix({"Train Loss": running_loss / (train_loader_tqdm.n + 1)})

        avg_train_loss = running_loss / len(train_loader)  # Average loss per batch
        training_loss.append(avg_train_loss)
        #print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}')

        val_running_loss = 0.0
        student_model.eval()
        val_loader_tqdm = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch")

        # Validation phase
        with torch.no_grad():
            for val_images in val_loader_tqdm:
                val_images = val_images.to(device)

                student_outputs = student_model(val_images)
                loss = criterion(student_outputs, val_images)

                val_running_loss += loss.item()
                val_loader_tqdm.set_postfix({"Validation Loss": val_running_loss / (val_loader_tqdm.n + 1)})

        avg_val_loss = val_running_loss / len(val_loader)  # Average loss per batch
        validation_loss.append(avg_val_loss)
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')

        # Check if the current validation loss is the best we've seen so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(student_model.state_dict(), '/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/resnet50_best.pth')
            print(f'Saving best model with validation loss-------------------------------------> {best_val_loss}')

        scheduler.step(avg_val_loss)

        # Print the current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Learning rate: {current_lr:.6f}")
        
        torch.save(student_model.state_dict(), '//media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/resnet50_latest.pth')

        # Early stopping
        early_stopping(avg_val_loss, student_model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        student_model.train()

train_model(student_model, sam_encoder, train_loader, val_loader, criterion, optimizer, num_epochs=100)


Epoch 1/100: 100%|██████████| 567/567 [05:39<00:00,  1.67batch/s, Train Loss=1.97]
Validation 1/100: 100%|██████████| 30/30 [00:17<00:00,  1.74batch/s, Validation Loss=465]    


Epoch 1, Training Loss: 1.9722788001704679, Validation Loss: 465.3054066936175
Saving best model with validation loss-------------------------------------> 465.3054066936175
Learning rate: 0.000100
Validation loss decreased (inf -------> 465.3054).  Saving model ...


Epoch 2/100: 100%|██████████| 567/567 [04:48<00:00,  1.97batch/s, Train Loss=1.99]
Validation 2/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=266]   


Epoch 2, Training Loss: 1.9851453178143375, Validation Loss: 266.253952729702
Saving best model with validation loss-------------------------------------> 266.253952729702
Learning rate: 0.000100
Validation loss decreased (465.3054 -------> 266.2540).  Saving model ...


Epoch 3/100: 100%|██████████| 567/567 [04:43<00:00,  2.00batch/s, Train Loss=1.97]
Validation 3/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=1.2e+3] 


Epoch 3, Training Loss: 1.973053394802033, Validation Loss: 1199.3242646773656
Learning rate: 0.000100
EarlyStopping counter: 1 out of 10


Epoch 4/100: 100%|██████████| 567/567 [04:43<00:00,  2.00batch/s, Train Loss=1.98]
Validation 4/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=1.47e+3]


Epoch 4, Training Loss: 1.9750590551467169, Validation Loss: 1468.857564496994
Learning rate: 0.000100
EarlyStopping counter: 2 out of 10


Epoch 5/100: 100%|██████████| 567/567 [04:43<00:00,  2.00batch/s, Train Loss=1.98]
Validation 5/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=304]   


Epoch 5, Training Loss: 1.975257945018681, Validation Loss: 304.2143719991048
Learning rate: 0.000100
EarlyStopping counter: 3 out of 10


Epoch 6/100: 100%|██████████| 567/567 [04:43<00:00,  2.00batch/s, Train Loss=1.98]
Validation 6/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=1.42e+3]


Epoch 6, Training Loss: 1.977608647716529, Validation Loss: 1415.5055540919304
Learning rate: 0.000100
EarlyStopping counter: 4 out of 10


Epoch 7/100: 100%|██████████| 567/567 [04:43<00:00,  2.00batch/s, Train Loss=1.99]
Validation 7/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=47.7]


Epoch 7, Training Loss: 1.993566097616098, Validation Loss: 47.66214406490326
Saving best model with validation loss-------------------------------------> 47.66214406490326
Learning rate: 0.000100
Validation loss decreased (266.2540 -------> 47.6621).  Saving model ...


Epoch 8/100: 100%|██████████| 567/567 [04:43<00:00,  2.00batch/s, Train Loss=1.99]
Validation 8/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=1.59e+3]


Epoch 8, Training Loss: 1.9908506308913863, Validation Loss: 1589.9158057490984
Learning rate: 0.000100
EarlyStopping counter: 1 out of 10


Epoch 9/100: 100%|██████████| 567/567 [04:43<00:00,  2.00batch/s, Train Loss=2.01]
Validation 9/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=14.4]


Epoch 9, Training Loss: 2.013128630908919, Validation Loss: 14.413171581427257
Saving best model with validation loss-------------------------------------> 14.413171581427257
Learning rate: 0.000100
Validation loss decreased (47.6621 -------> 14.4132).  Saving model ...


Epoch 10/100: 100%|██████████| 567/567 [04:49<00:00,  1.96batch/s, Train Loss=1.99]
Validation 10/100: 100%|██████████| 30/30 [00:16<00:00,  1.78batch/s, Validation Loss=1.04e+3]


Epoch 10, Training Loss: 1.9878731366818545, Validation Loss: 1036.3996162931123
Learning rate: 0.000100
EarlyStopping counter: 1 out of 10


Epoch 11/100: 100%|██████████| 567/567 [05:28<00:00,  1.73batch/s, Train Loss=1.99]
Validation 11/100: 100%|██████████| 30/30 [00:16<00:00,  1.86batch/s, Validation Loss=1.25e+3]


Epoch 11, Training Loss: 1.9904584260213942, Validation Loss: 1247.7036024570466
Learning rate: 0.000100
EarlyStopping counter: 2 out of 10


Epoch 12/100: 100%|██████████| 567/567 [05:33<00:00,  1.70batch/s, Train Loss=1.96]
Validation 12/100: 100%|██████████| 30/30 [00:18<00:00,  1.63batch/s, Validation Loss=828]    


Epoch 12, Training Loss: 1.962220973859178, Validation Loss: 827.702973083655
Learning rate: 0.000100
EarlyStopping counter: 3 out of 10


Epoch 13/100: 100%|██████████| 567/567 [05:43<00:00,  1.65batch/s, Train Loss=2.01]
Validation 13/100: 100%|██████████| 30/30 [00:16<00:00,  1.87batch/s, Validation Loss=1.38e+3]


Epoch 13, Training Loss: 2.012135513455359, Validation Loss: 1377.973113612334
Learning rate: 0.000100
EarlyStopping counter: 4 out of 10


Epoch 14/100: 100%|██████████| 567/567 [05:14<00:00,  1.80batch/s, Train Loss=2.01]
Validation 14/100: 100%|██████████| 30/30 [00:16<00:00,  1.87batch/s, Validation Loss=834]    


Epoch 14, Training Loss: 2.0086981247341824, Validation Loss: 834.0124917745591
Learning rate: 0.000100
EarlyStopping counter: 5 out of 10


Epoch 15/100: 100%|██████████| 567/567 [05:15<00:00,  1.80batch/s, Train Loss=1.97]
Validation 15/100: 100%|██████████| 30/30 [00:16<00:00,  1.86batch/s, Validation Loss=3.33e+3]


Epoch 15, Training Loss: 1.970349653898303, Validation Loss: 3332.34144598643
Epoch 00015: reducing learning rate of group 0 to 1.0000e-05.
Learning rate: 0.000010
EarlyStopping counter: 6 out of 10


Epoch 16/100: 100%|██████████| 567/567 [05:08<00:00,  1.84batch/s, Train Loss=1.99]
Validation 16/100: 100%|██████████| 30/30 [00:14<00:00,  2.08batch/s, Validation Loss=784]    


Epoch 16, Training Loss: 1.986466644302247, Validation Loss: 784.4091815948486
Learning rate: 0.000010
EarlyStopping counter: 7 out of 10


Epoch 17/100: 100%|██████████| 567/567 [05:27<00:00,  1.73batch/s, Train Loss=1.98]
Validation 17/100: 100%|██████████| 30/30 [00:18<00:00,  1.67batch/s, Validation Loss=617]    


Epoch 17, Training Loss: 1.9831410463524874, Validation Loss: 616.9310113946597
Learning rate: 0.000010
EarlyStopping counter: 8 out of 10


Epoch 18/100: 100%|██████████| 567/567 [05:33<00:00,  1.70batch/s, Train Loss=2]   
Validation 18/100: 100%|██████████| 30/30 [00:16<00:00,  1.86batch/s, Validation Loss=1.19e+3]


Epoch 18, Training Loss: 1.9977051919520006, Validation Loss: 1191.4472035129866
Learning rate: 0.000010
EarlyStopping counter: 9 out of 10


Epoch 19/100: 100%|██████████| 567/567 [05:55<00:00,  1.60batch/s, Train Loss=2]   
Validation 19/100: 100%|██████████| 30/30 [00:15<00:00,  1.89batch/s, Validation Loss=697]    


Epoch 19, Training Loss: 1.9971257395634996, Validation Loss: 697.3010773181916
Learning rate: 0.000010
EarlyStopping counter: 10 out of 10
Early stopping


: 

In [None]:
from sklearn.model_selection import KFold
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

# Define the dataset and model (assumed to be already defined)
# dataset = ...
# student_model = ...
# sam_encoder = ...

# Set parameters
k_folds = 5
num_epochs = 100
batch_size = 32
learning_rate = 0.001

# Set up the KFold cross-validator
kfold = KFold(n_splits=k_folds, shuffle=True)

# Set up EarlyStopping
early_stopping = EarlyStopping(patience=10, verbose=True, path='resnet50_early_stop.pth')

# Loss and optimizer
criterion = CombinedDistillationLoss(sam_encoder, layers).to(device)
optimizer = optim.SGD(student_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

# K-Fold Cross-Validation
for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
    print(f'Fold {fold + 1}/{k_folds}')
    
    # Sample elements randomly from a given list of indices, no replacement
    train_subsampler = Subset(dataset, train_idx)
    val_subsampler = Subset(dataset, val_idx)

    # Define data loaders for training and validation
    train_loader = DataLoader(train_subsampler, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subsampler, batch_size=batch_size, shuffle=False)

    best_val_loss = float('inf')
    student_model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        train_loader_tqdm = tqdm(train_loader, desc=f"Fold {fold + 1}, Epoch {epoch + 1}/{num_epochs}", unit="batch")
        
        # Training phase
        for images, _ in train_loader_tqdm:
            images = images.to(device)

            optimizer.zero_grad()
            student_outputs = student_model(images)
            loss = criterion(student_outputs, images)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_loader_tqdm.set_postfix({"Train Loss": running_loss / (train_loader_tqdm.n + 1)})

        avg_train_loss = running_loss / len(train_loader)  # Average loss per batch
        print(f'Epoch {epoch + 1}, Training Loss: {avg_train_loss}')

        val_running_loss = 0.0
        student_model.eval()
        val_loader_tqdm = tqdm(val_loader, desc=f"Fold {fold + 1}, Validation {epoch + 1}/{num_epochs}", unit="batch")

        # Validation phase
        with torch.no_grad():
            for val_images, _ in val_loader_tqdm:
                val_images = val_images.to(device)

                student_outputs = student_model(val_images)
                loss = criterion(student_outputs, val_images)

                val_running_loss += loss.item()
                val_loader_tqdm.set_postfix({"Validation Loss": val_running_loss / (val_loader_tqdm.n + 1)})

        avg_val_loss = val_running_loss / len(val_loader)  # Average loss per batch
        print(f'Epoch {epoch + 1}, Validation Loss: {avg_val_loss}')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(student_model.state_dict(), f'resnet50_best_fold{fold + 1}.pth')
            print(f'Saving best model with validation loss: {best_val_loss}')

        torch.save(student_model.state_dict(), f'resnet50_latest_fold{fold + 1}.pth')

        # Early stopping
        early_stopping(avg_val_loss, student_model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        student_model.train()

print('K-Fold Cross-Validation complete.')
