In [1]:
# Use the %pip magic so packages land in the running kernel's environment, and
# pin to the versions this notebook was last executed with (per the recorded output).
%pip install torchsummary==1.5.1 wandb==0.18.3

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
from torchsummary import summary
import wandb


# Hyperparameters and filesystem locations for this run. The same dict is
# uploaded to Weights & Biases as the run configuration below.
config = {
    'batch_size': 64,
    'lr': 0.001,
    'epochs': 100,
    'data_dir': "/kaggle/input/mydataset/dataset/EM_processed/training",
    'data_label_dir': "/kaggle/input/mydataset/dataset/Label_processed/training",
    'checkpoint_dir': "/kaggle/working/",
}

# SECURITY: never embed the W&B API key in the notebook. Provide it through the
# WANDB_API_KEY environment variable (e.g. populated from the platform's secret
# store). When the variable is unset, key=None lets wandb fall back to its own
# credential lookup. Any key that was ever committed in this file should be
# treated as compromised and rotated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# `run` is a module-level handle; train() logs per-batch metrics through it.
run = wandb.init(
    reinit=True,
    project="CASENet",
    config=config,
)

# Custom Dataset Class
class BWImageDataset(Dataset):
    """Dataset of grayscale images paired with binary masks stored in two folders.

    Files are paired by position: the i-th entry of the sorted image listing is
    matched with the i-th entry of the sorted label listing. Only the number of
    files is checked, not that the names correspond.

    Args:
        image_dir: Directory containing the input images.
        label_dir: Directory containing the label images.
        transform: Optional callable applied to both the image and the label.
            When omitted, ``transforms.ToTensor()`` is used for both.
    """

    def __init__(self, image_dir, label_dir, transform=None):
        super().__init__()
        self.image_dir, self.label_dir = image_dir, label_dir
        self.transform = transform

        # Sorted listings give a deterministic, position-based pairing.
        self.image_names = sorted(os.listdir(image_dir))
        self.label_names = sorted(os.listdir(label_dir))
        assert len(self.image_names) == len(self.label_names), "Mismatch between images and labels"

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, label)``; label holds only 0.0 / 1.0."""
        image_path = os.path.join(self.image_dir, self.image_names[idx])
        label_path = os.path.join(self.label_dir, self.label_names[idx])

        # Both files are forced to single-channel ('L') mode.
        image = Image.open(image_path).convert('L')
        label = Image.open(label_path).convert('L')

        # Fall back to a plain tensor conversion when no transform was supplied.
        convert = self.transform if self.transform is not None else transforms.ToTensor()
        image = convert(image)
        label = convert(label)

        # Binarize the mask: every non-zero value becomes 1.0, zero stays 0.0.
        return image, (label > 0).float()

# Function to modify ResNet to accept single-channel input
def get_resnet_backbone(backbone_name='resnet50', pretrained=True):
    """Build a torchvision ResNet whose first convolution accepts 1-channel input.

    Args:
        backbone_name: ``'resnet50'`` or ``'resnet101'``.
        pretrained: Forwarded to the torchvision constructor. When true, the new
            single-channel ``conv1`` is initialised from the selected backbone's
            own RGB ``conv1`` filters, summed over the input-channel axis.

    Returns:
        The ResNet module with ``conv1`` replaced by a ``Conv2d(1, 64, 7, 2, 3)``.

    Raises:
        ValueError: If ``backbone_name`` is not one of the supported names.
    """
    if backbone_name == 'resnet50':
        backbone = models.resnet50(pretrained=pretrained)
    elif backbone_name == 'resnet101':
        backbone = models.resnet101(pretrained=pretrained)
    else:
        raise ValueError('Unsupported backbone {}'.format(backbone_name))

    # Keep a handle on the RGB stem before it is replaced, so the grayscale
    # filters come from THIS backbone (not from a separately loaded ResNet-50).
    rgb_conv1 = backbone.conv1
    gray_conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    if pretrained:
        # Collapse the three input-channel filters into one by summation:
        # (64, 3, 7, 7) -> (64, 1, 7, 7).
        with torch.no_grad():
            gray_conv1.weight.copy_(rgb_conv1.weight.sum(dim=1, keepdim=True))

    backbone.conv1 = gray_conv1
    return backbone

# BaseNet Class (modified to accept 1-channel input)
class BaseNet(nn.Module):
    """Feature extractor wrapping a single-channel ResNet backbone.

    Args:
        backbone_name: Name passed to ``get_resnet_backbone``; the backbone is
            always requested with ``pretrained=True``.
    """

    def __init__(self, backbone_name='resnet50'):
        super().__init__()
        self.backbone = get_resnet_backbone(backbone_name, pretrained=True)

    def forward(self, x):
        """Run ``x`` through the stem and the four residual stages, in order.

        The backbone's pooling/classifier layers are not invoked here; the
        output of ``layer4`` is returned as-is.
        """
        net = self.backbone
        pipeline = (
            # Stem
            net.conv1, net.bn1, net.relu, net.maxpool,
            # Residual stages
            net.layer1, net.layer2, net.layer3, net.layer4,
        )

        features = x
        for stage in pipeline:
            features = stage(features)
        return features

# Segmentation Head for Grayscale Output
class GrayscaleSegmentationHead(nn.Module):
    """1x1 convolution to a single channel followed by a fixed x32 bilinear upsample.

    Args:
        in_channels: Number of channels in the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        """Return raw logits (no activation) at 32x the spatial size of ``x``."""
        logits = self.conv(x)
        # The factor of 32 is fixed, so the result only matches the network
        # input when that input's height and width are multiples of 32.
        return nn.functional.interpolate(
            logits,
            scale_factor=32,
            mode='bilinear',
            align_corners=False,
        )

# Complete Model
class SegmentationModel(nn.Module):
    """Backbone plus segmentation head producing one logit per input pixel.

    Args:
        backbone_name: Forwarded to ``BaseNet`` to select the ResNet variant.
    """

    def __init__(self, backbone_name='resnet50'):
        super(SegmentationModel, self).__init__()
        self.base_net = BaseNet(backbone_name)
        # The head is built for a 2048-channel feature map.
        self.seg_head = GrayscaleSegmentationHead(2048)

    def forward(self, x):
        """Return logits with the same height and width as ``x``.

        The head upsamples by a fixed factor, which only reproduces the input
        resolution when height and width are multiples of 32. For any other
        size the logits are resized to the input resolution so they can be
        compared element-wise with a same-sized target. Aligned inputs are
        returned unchanged.
        """
        input_size = tuple(x.shape[-2:])
        x = self.base_net(x)
        x = self.seg_head(x)
        if tuple(x.shape[-2:]) != input_size:
            x = nn.functional.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return x

# Training Function
from tqdm import tqdm

def train(model, device, train_loader, optimizer, criterion, epoch):
    """Run one optimisation pass over ``train_loader``.

    Each batch's loss is logged to the module-level wandb ``run`` under the key
    ``'train_loss'`` and shown on the progress bar.

    Args:
        model: Network to optimise; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        epoch: Epoch number, used only for display.

    Returns:
        The mean per-batch loss for this epoch as a Python float, or ``nan``
        when the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    progress_bar = tqdm(train_loader, total=num_batches, desc=f'Epoch {epoch}', leave=False)

    for data, target in progress_bar:
        data = data.to(device, dtype=torch.float32)
        # Targets are cast to float to match the float logits passed to the loss.
        target = target.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Convert to a plain float exactly once: avoids handing a graph-attached
        # tensor to the logger and avoids repeated host/device synchronisation.
        batch_loss = loss.item()
        run.log({'train_loss': batch_loss})
        running_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    # Guard against an empty loader instead of dividing by zero.
    epoch_loss = running_loss / num_batches if num_batches else float('nan')
    print('Training Loss after epoch {}: {:.6f}'.format(epoch, epoch_loss))
    return epoch_loss


# Main Function
def main():
    """Train the segmentation model end to end and save its weights.

    Reads all settings from the module-level ``config`` dict: builds the
    dataset and loader, constructs the ResNet-50 based model, trains for
    ``config['epochs']`` epochs and writes the state dict to
    ``segmentation_model.pth`` inside ``config['checkpoint_dir']``.
    """
    # Create the output directory up front so a missing path fails (or is fixed)
    # before training starts rather than at the final save.
    os.makedirs(config['checkpoint_dir'], exist_ok=True)
    checkpoint_path = os.path.join(config['checkpoint_dir'], 'segmentation_model.pth')

    # Transforms (applied identically to images and labels by the dataset)
    transform = transforms.Compose([
        # transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    # Dataset and DataLoader
    train_dataset = BWImageDataset(config['data_dir'], config['data_label_dir'], transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

    # Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SegmentationModel(backbone_name='resnet50').to(device)

    # Print a layer-by-layer summary for a 1x256x256 input.
    summary(model, (1, 256, 256))

    # Optimizer and loss; the model emits raw logits, hence BCEWithLogitsLoss.
    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.BCEWithLogitsLoss()

    # Training loop (epochs are 1-based for display)
    for epoch in range(1, config['epochs'] + 1):
        train(model, device, train_loader, optimizer, criterion, epoch)

    # Save the trained weights
    torch.save(model.state_dict(), checkpoint_path)
    print('Model saved as segmentation_model.pth')

# Start training only when this code runs as the main program (not on import).
if __name__ == '__main__':
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mseitomoyi[0m ([33mseitomoyi-carnegie-mellon-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.18.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20241026_044513-xt86nszx[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mgenerous-wildflower-4[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/seitomoyi-carnegie-mellon-university/CASENet[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/seitomoyi-carnegie-mellon-university/CASENet/runs/

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,136
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
         MaxPool2d-4           [-1, 64, 64, 64]               0
            Conv2d-5           [-1, 64, 64, 64]           4,096
       BatchNorm2d-6           [-1, 64, 64, 64]             128
              ReLU-7           [-1, 64, 64, 64]               0
            Conv2d-8           [-1, 64, 64, 64]          36,864
       BatchNorm2d-9           [-1, 64, 64, 64]             128
             ReLU-10           [-1, 64, 64, 64]               0
           Conv2d-11          [-1, 256, 64, 64]          16,384
      BatchNorm2d-12          [-1, 256, 64, 64]             512
           Conv2d-13          [-1, 256, 64, 64]          16,384
      BatchNorm2d-14          [-1, 256,

                                                                      

Training Loss after epoch 1: 0.121132


                                                                      

Training Loss after epoch 2: 0.088650


                                                                       

Training Loss after epoch 3: 0.080725


                                                                       

Training Loss after epoch 4: 0.076005


                                                                       

Training Loss after epoch 5: 0.076377


                                                                       

Training Loss after epoch 6: 0.066081


                                                                       

Training Loss after epoch 7: 0.063836


                                                                       

Training Loss after epoch 8: 0.060089


                                                                       

Training Loss after epoch 9: 0.052383


                                                                        

Training Loss after epoch 10: 0.050044


                                                                        

Training Loss after epoch 11: 0.042522


                                                                        

Training Loss after epoch 12: 0.037193


                                                                         

Training Loss after epoch 13: 0.034212


                                                                        

Training Loss after epoch 14: 0.029344


                                                                         

Training Loss after epoch 15: 0.026449


                                                                        

Training Loss after epoch 16: 0.022783


                                                                        

Training Loss after epoch 17: 0.020375


                                                                         

Training Loss after epoch 18: 0.019138


                                                                         

Training Loss after epoch 19: 0.019022


                                                                        

Training Loss after epoch 20: 0.018863


                                                                        

Training Loss after epoch 21: 0.020086


                                                                         

Training Loss after epoch 22: 0.019523


                                                                         

Training Loss after epoch 23: 0.016010


                                                                        

Training Loss after epoch 24: 0.014061


                                                                       

Training Loss after epoch 25: 0.011982


                                                                         

Training Loss after epoch 26: 0.011407


                                                                         

Training Loss after epoch 27: 0.010282


                                                                        

Training Loss after epoch 28: 0.011862


                                                                       

Training Loss after epoch 29: 0.031071


                                                                        

Training Loss after epoch 30: 0.025613


                                                                       

Training Loss after epoch 31: 0.016086


                                                                         

Training Loss after epoch 32: 0.011416


                                                                        

Training Loss after epoch 33: 0.010063


                                                                         

Training Loss after epoch 34: 0.009052


                                                                        

Training Loss after epoch 35: 0.007985


                                                                         

Training Loss after epoch 36: 0.007651


                                                                         

Training Loss after epoch 37: 0.007861


                                                                        

Training Loss after epoch 38: 0.007541


                                                                        

Training Loss after epoch 39: 0.007241


                                                                        

Training Loss after epoch 40: 0.006667


                                                                         

Training Loss after epoch 41: 0.006807


                                                                         

Training Loss after epoch 42: 0.007776


                                                                         

Training Loss after epoch 43: 0.008168


                                                                        

Training Loss after epoch 44: 0.034171


                                                                        

Training Loss after epoch 45: 0.029987


                                                                        

Training Loss after epoch 46: 0.015173


                                                                         

Training Loss after epoch 47: 0.010269


                                                                         

Training Loss after epoch 48: 0.007982


                                                                          

Training Loss after epoch 49: 0.006843


                                                                         

Training Loss after epoch 50: 0.006285


                                                                         

Training Loss after epoch 51: 0.005986


                                                                         

Training Loss after epoch 52: 0.006022


                                                                        

Training Loss after epoch 53: 0.005772


                                                                         

Training Loss after epoch 54: 0.005600


                                                                         

Training Loss after epoch 55: 0.005187


                                                                         

Training Loss after epoch 56: 0.005269


                                                                         

Training Loss after epoch 57: 0.005138


                                                                         

Training Loss after epoch 58: 0.005141


                                                                         

Training Loss after epoch 59: 0.005050


                                                                         

Training Loss after epoch 60: 0.004904


                                                                         

Training Loss after epoch 61: 0.005492


                                                                         

Training Loss after epoch 62: 0.005289


                                                                        

Training Loss after epoch 63: 0.046071


                                                                         

Training Loss after epoch 64: 0.023580


                                                                        

Training Loss after epoch 65: 0.011983


                                                                         

Training Loss after epoch 66: 0.007962


                                                                         

Training Loss after epoch 67: 0.006235


                                                                         

Training Loss after epoch 68: 0.005496


                                                                         

Training Loss after epoch 69: 0.005279


                                                                        

Training Loss after epoch 70: 0.005195


                                                                         

Training Loss after epoch 71: 0.004950


                                                                        

Training Loss after epoch 72: 0.004817


                                                                         

Training Loss after epoch 73: 0.004967


                                                                        

Training Loss after epoch 74: 0.004733


                                                                          

Training Loss after epoch 75: 0.004595


                                                                         

Training Loss after epoch 76: 0.007537


                                                                        

Training Loss after epoch 77: 0.026396


                                                                         

Training Loss after epoch 78: 0.013720


                                                                         

Training Loss after epoch 79: 0.010611


                                                                         

Training Loss after epoch 80: 0.006515


                                                                         

Training Loss after epoch 81: 0.005331


                                                                         

Training Loss after epoch 82: 0.004928


                                                                         

Training Loss after epoch 83: 0.004612


                                                                         

Training Loss after epoch 84: 0.004558


                                                                        

Training Loss after epoch 85: 0.004439


                                                                         

Training Loss after epoch 86: 0.004285


                                                                         

Training Loss after epoch 87: 0.004593


                                                                          

Training Loss after epoch 88: 0.004561


                                                                       

Training Loss after epoch 89: 0.004382


                                                                         

Training Loss after epoch 90: 0.004437


                                                                        

Training Loss after epoch 91: 0.025424


                                                                       

Training Loss after epoch 92: 0.014293


                                                                         

Training Loss after epoch 93: 0.008125


                                                                         

Training Loss after epoch 94: 0.005938


                                                                        

Training Loss after epoch 95: 0.004891


                                                                        

Training Loss after epoch 96: 0.004427


                                                                         

Training Loss after epoch 97: 0.004396


                                                                         

Training Loss after epoch 98: 0.004291


                                                                        

Training Loss after epoch 99: 0.004262


                                                                          

Training Loss after epoch 100: 0.004275
Model saved as segmentation_model.pth


