In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import numpy as np
import torch.nn.functional as F
import mlflow
import mlflow.pytorch
from tqdm.notebook import tqdm

In [2]:
# Use the first CUDA GPU when PyTorch reports one as available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
class ClassificationDataset(Dataset):
    """Folder-per-class image dataset for the four-class classification task.

    ``root_dir`` is scanned for sub-directories whose names are keys of
    ``class_to_idx``; every regular file inside such a directory becomes one
    sample. Directories with any other name are ignored.

    Args:
        root_dir: Directory that contains the per-class folders.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes are populated eagerly in ``__init__``: ``images`` holds file
    paths and ``labels`` the matching integer class indices.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_to_idx = {
            'glioma': 0,
            'meningioma': 1,
            'notumor': 2,
            'pituitary': 3
        }
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping identical across runs and machines.
        for class_name in sorted(os.listdir(root_dir)):
            if class_name not in self.class_to_idx:
                continue
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            label = self.class_to_idx[class_name]
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Nested directories cannot be opened as images; skip them here
                # instead of failing later inside __getitem__.
                if not os.path.isfile(img_path):
                    continue
                self.images.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied; ``label`` is the integer class index.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert().
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Compare against None so a transform object that happens to be falsy
        # (for example an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for the segmentation task.

    Every regular file in ``image_dir`` is one sample; its mask is looked up
    under the same file name inside ``mask_dir``.

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the masks, named like their images.
        transform: Optional callable invoked as ``transform(image, mask)``
            that returns the transformed ``(image, mask)`` pair.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary and may include sub-directories; keep
        # only regular files and sort them so indexing is reproducible.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The image is converted to RGB and the mask to single-channel
        grayscale before the optional joint transform is applied.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)
        # Context managers release the file handles once convert() has copied
        # the pixel data.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')  # grayscale mask
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return image, mask

In [4]:
# class ResidualInceptionBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1):
#         super(ResidualInceptionBlock, self).__init__()
#         self.branch1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
        
#         self.branch2 = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
#         )
        
#         self.branch3 = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size=5, stride=stride, padding=2)
#         )
        
#         self.branch_pool = nn.Sequential(
#             nn.MaxPool2d(kernel_size=3, stride=stride, padding=1),
#             nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
#         )
        
#         self.bn = nn.BatchNorm2d(4 * out_channels)
#         self.relu = nn.ReLU(inplace=True)
        
#         # Residual connection
#         self.shortcut = nn.Sequential()
#         if stride != 1 or in_channels != 4 * out_channels:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, 4 * out_channels, kernel_size=1, stride=stride),
#                 nn.BatchNorm2d(4 * out_channels)
#             )
        
#     def forward(self, x):
#         branch1 = self.branch1(x)
        
#         branch2 = self.branch2(x)
        
#         branch3 = self.branch3(x)
        
#         branch4 = self.branch_pool(x)
        
#         outputs = torch.cat([branch1, branch2, branch3, branch4], 1)
#         outputs = self.bn(outputs)
        
#         residual = self.shortcut(x)
#         outputs += residual
#         outputs = self.relu(outputs)
#         return outputs

# class DetectionModel(nn.Module):
#     def __init__(self, num_classes=2):
#         super(DetectionModel, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
#         self.bn1   = nn.BatchNorm2d(64)
#         self.relu  = nn.ReLU(inplace=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
#         self.layer1 = ResidualInceptionBlock(64, 64)
#         self.layer2 = ResidualInceptionBlock(256, 128, stride=2)
#         self.layer3 = ResidualInceptionBlock(512, 256, stride=2)
#         self.layer4 = ResidualInceptionBlock(1024, 512, stride=2)
        
#         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
#         self.fc      = nn.Linear(2048, num_classes)
        
#     def forward(self, x):
#         x = self.relu(self.bn1(self.conv1(x)))  # [B, 64, H/2, W/2]
#         x = self.maxpool(x)  # [B, 64, H/4, W/4]
        
#         x = self.layer1(x)   # [B, 256, H/4, W/4]
#         x = self.layer2(x)   # [B, 512, H/8, W/8]
#         x = self.layer3(x)   # [B, 1024, H/16, W/16]
#         x = self.layer4(x)   # [B, 2048, H/32, W/32]
        
#         x = self.avgpool(x)  # [B, 2048, 1, 1]
#         x = x.view(x.size(0), -1)  # [B, 2048]
#         x = self.fc(x)       # [B, num_classes]
#         return x

# # Instantiate the model
# def create_detection_model(num_classes=2):
#     model = DetectionModel(num_classes=num_classes)
#     return model

In [5]:
class ResidualInceptionBlock(nn.Module):
    """Four-branch inception-style block with a residual shortcut.

    Each branch produces ``out_channels // 4`` feature maps; the four results
    are concatenated along the channel axis, batch-normalised, added to the
    shortcut and passed through ReLU. The shortcut is the identity when the
    input already has ``out_channels`` channels and ``stride == 1``; otherwise
    it is a strided 1x1 convolution followed by batch norm.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor. Must be a
            positive multiple of 4 because it is split evenly over 4 branches.
        stride: Spatial stride applied once in every branch and the shortcut.

    Raises:
        ValueError: If ``out_channels`` is not a positive multiple of 4.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        # Without this check a non-multiple of 4 builds silently and only
        # fails at forward time with a BatchNorm channel-count mismatch.
        if out_channels <= 0 or out_channels % 4 != 0:
            raise ValueError(
                f"out_channels must be a positive multiple of 4, got {out_channels}"
            )
        super(ResidualInceptionBlock, self).__init__()
        branch_channels = out_channels // 4

        # Branch 1: strided 1x1 convolution.
        self.branch1 = nn.Conv2d(in_channels, branch_channels, kernel_size=1, stride=stride)

        # Branch 2: strided 1x1 reduction, then a 3x3 convolution.
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, stride=stride),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_channels, branch_channels, kernel_size=3, stride=1, padding=1)
        )

        # Branch 3: strided 1x1 reduction, then a 5x5 convolution.
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, stride=stride),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_channels, branch_channels, kernel_size=5, stride=1, padding=2)
        )

        # Branch 4: strided 3x3 max-pool, then a 1x1 projection.
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=stride, padding=1),
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, stride=1, padding=0)
        )

        # Normalisation and activation applied to the concatenated branches.
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut: identity unless the channel count or resolution changes.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Run all four branches on ``x`` and return the fused residual output."""
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch_pool(x)
        outputs = torch.cat([branch1, branch2, branch3, branch4], 1)
        outputs = self.bn(outputs)
        residual = self.shortcut(x)
        outputs += residual
        outputs = self.relu(outputs)
        return outputs
    
def make_layer(block, in_channels, out_channels, num_blocks, stride=1):
    """Stack ``block`` instances into one ``nn.Sequential`` stage.

    The first block maps ``in_channels`` to ``out_channels`` using ``stride``;
    each of the remaining ``num_blocks - 1`` blocks keeps ``out_channels`` and
    uses stride 1. One block is always created, even when ``num_blocks`` is
    below 1.
    """
    entry_block = block(in_channels, out_channels, stride=stride)
    follow_up_blocks = [
        block(out_channels, out_channels, stride=1)
        for _ in range(num_blocks - 1)
    ]
    return nn.Sequential(entry_block, *follow_up_blocks)


# class DeeperDetectionModel(nn.Module):
#     def __init__(self, num_classes=2):
#         super(DeeperDetectionModel, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
#         self.bn1   = nn.BatchNorm2d(64)
#         self.relu  = nn.ReLU(inplace=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         # Define the number of blocks in each layer
#         self.layer1 = make_layer(ResidualInceptionBlock, 64, 128, num_blocks=3, stride=1)
#         self.layer1_5 = make_layer(ResidualInceptionBlock, 128, 256, num_blocks=3, stride=1)
#         self.layer2 = make_layer(ResidualInceptionBlock, 256, 512, num_blocks=4, stride=2)
#         # self.layer2_5 = make_layer(ResidualInceptionBlock, 512, 512, num_blocks=4, stride=1)
#         self.layer2_5 = make_layer(ResidualInceptionBlock, 512, 512, num_blocks=6, stride=1)
#         self.layer3 = make_layer(ResidualInceptionBlock, 512, 1024, num_blocks=6, stride=2)
#         self.layer4 = make_layer(ResidualInceptionBlock, 1024, 2048, num_blocks=3, stride=2)

#         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
#         self.fc      = nn.Linear(2048, num_classes)

#     def forward(self, x):
#         x = self.relu(self.bn1(self.conv1(x)))  # [B, 64, H/2, W/2]
#         x = self.maxpool(x)                     # [B, 64, H/4, W/4]

#         x = self.layer1(x)                      # [B, 256, H/4, W/4]
#         x = self.layer1_5(x)                    # [B, 256, H/4, W/4]
#         x = self.layer2(x)                      # [B, 512, H/8, W/8]
#         x = self.layer2_5(x)                    # [B, 512, H/8, W/8]
#         x = self.layer3(x)                      # [B, 1024, H/16, W/16]
#         x = self.layer4(x)                      # [B, 2048, H/32, W/32]

#         x = self.avgpool(x)                     # [B, 2048, 1, 1]
#         x = x.view(x.size(0), -1)               # [B, 2048]
#         x = self.fc(x)                          # [B, num_classes]
#         return x
    
class DeeperDetectionModel(nn.Module):
    """Image classifier built from stacked ``ResidualInceptionBlock`` stages.

    A 7x7 stride-2 stem and a stride-2 max-pool are followed by four stages of
    3, 4, 6 and 3 blocks (256, 512, 1024 and 2048 channels), global average
    pooling and a linear layer that outputs ``num_classes`` logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Stem: reduces the spatial resolution by 4 in total.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages: every stage after the first halves the resolution.
        self.layer1 = make_layer(ResidualInceptionBlock, 64, 256, num_blocks=3, stride=1)
        self.layer2 = make_layer(ResidualInceptionBlock, 256, 512, num_blocks=4, stride=2)
        self.layer3 = make_layer(ResidualInceptionBlock, 512, 1024, num_blocks=6, stride=2)
        self.layer4 = make_layer(ResidualInceptionBlock, 1024, 2048, num_blocks=3, stride=2)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Map a ``[B, 3, H, W]`` batch to ``[B, num_classes]`` logits."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)            # [B, 64, H/2, W/2]
        x = self.maxpool(x)         # [B, 64, H/4, W/4]

        # layer1 -> 256 ch, layer2 -> 512 ch, layer3 -> 1024 ch, layer4 -> 2048 ch
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)         # [B, 2048, 1, 1]
        x = x.view(x.size(0), -1)   # [B, 2048]
        return self.fc(x)           # [B, num_classes]

# Instantiate the model
def create_detection_model(num_classes=2):
    """Return a freshly initialised ``DeeperDetectionModel`` with ``num_classes`` outputs."""
    return DeeperDetectionModel(num_classes=num_classes)

In [6]:
class ResidualInceptionBlock_segmentation(nn.Module):
    """Inception-style residual block used by the segmentation network.

    Unlike the classification block, every branch emits ``out_channels``
    feature maps, so the concatenated output has ``4 * out_channels``
    channels. The shortcut is the identity only when ``stride == 1`` and the
    input already has ``4 * out_channels`` channels; otherwise it is a strided
    1x1 convolution followed by batch norm.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        fused_channels = 4 * out_channels

        # 1x1 branch.
        self.branch1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)

        # 1x1 reduction followed by a strided 3x3 convolution.
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1),
        )

        # 1x1 reduction followed by a strided 5x5 convolution.
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=5, stride=stride, padding=2),
        )

        # Strided 3x3 max-pool followed by a 1x1 projection.
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=stride, padding=1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
        )

        self.bn = nn.BatchNorm2d(fused_channels)
        self.relu = nn.ReLU(inplace=True)

        # Project the input whenever its shape differs from the fused output.
        needs_projection = stride != 1 or in_channels != fused_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, fused_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(fused_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn(concat(branches(x))) + shortcut(x))``."""
        branch_outputs = [
            branch(x)
            for branch in (self.branch1, self.branch2, self.branch3, self.branch_pool)
        ]
        outputs = self.bn(torch.cat(branch_outputs, 1))
        outputs += self.shortcut(x)
        return self.relu(outputs)

class UpBlock(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, skip concatenation, residual block.

    ``up`` maps ``in_channels`` to ``out_channels`` while doubling the spatial
    size. The result is concatenated with a skip tensor of ``skip_channels``
    channels and fed to a ``ResidualInceptionBlock_segmentation``, which emits
    ``4 * out_channels`` channels.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        merged_channels = out_channels + skip_channels
        self.rib = ResidualInceptionBlock_segmentation(merged_channels, out_channels, stride=1)

    def forward(self, x, skip_connection):
        """Upsample ``x``, append ``skip_connection`` on the channel axis and refine."""
        upsampled = self.up(x)
        merged = torch.cat([upsampled, skip_connection], dim=1)
        return self.rib(merged)

class SegmentationModel(nn.Module):
    """U-Net-shaped encoder/decoder built from residual inception blocks.

    Four encoder blocks (128, 256, 512, 1024 output channels) are each followed
    by a 2x max-pool; a 2048-channel bottleneck sits at 1/16 resolution. Four
    ``UpBlock`` stages then restore the input resolution, each consuming the
    matching encoder output as a skip connection. A final 1x1 convolution maps
    the 128 decoder channels to ``out_channels``.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder: every block outputs 4x its second argument.
        self.enc1 = ResidualInceptionBlock_segmentation(in_channels, 32)   # -> 128 channels
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = ResidualInceptionBlock_segmentation(128, 64)           # -> 256 channels
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = ResidualInceptionBlock_segmentation(256, 128)          # -> 512 channels
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = ResidualInceptionBlock_segmentation(512, 256)          # -> 1024 channels
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = ResidualInceptionBlock_segmentation(1024, 512)   # -> 2048 channels

        # Decoder: each UpBlock also outputs 4x its out_channels argument.
        self.up4 = UpBlock(in_channels=2048, skip_channels=1024, out_channels=256)
        self.up3 = UpBlock(in_channels=1024, skip_channels=512, out_channels=128)
        self.up2 = UpBlock(in_channels=512, skip_channels=256, out_channels=64)
        self.up1 = UpBlock(in_channels=256, skip_channels=128, out_channels=32)

        self.final_conv = nn.Conv2d(128, out_channels, kernel_size=1)

    def forward(self, x):
        """Map a ``[B, in_channels, H, W]`` batch to ``[B, out_channels, H, W]`` logits."""
        # Encoder path; the outputs are kept as skip connections.
        skip1 = self.enc1(x)                    # [B, 128, H, W]
        skip2 = self.enc2(self.pool1(skip1))    # [B, 256, H/2, W/2]
        skip3 = self.enc3(self.pool2(skip2))    # [B, 512, H/4, W/4]
        skip4 = self.enc4(self.pool3(skip3))    # [B, 1024, H/8, W/8]

        features = self.bottleneck(self.pool4(skip4))   # [B, 2048, H/16, W/16]

        # Decoder path, deepest stage first.
        decoder_stages = (
            (self.up4, skip4),   # -> [B, 1024, H/8, W/8]
            (self.up3, skip3),   # -> [B, 512, H/4, W/4]
            (self.up2, skip2),   # -> [B, 256, H/2, W/2]
            (self.up1, skip1),   # -> [B, 128, H, W]
        )
        for up_block, skip in decoder_stages:
            features = up_block(features, skip)

        return self.final_conv(features)        # [B, out_channels, H, W]

# Instantiate the model
def create_segmentation_model(in_channels=3, out_channels=1):
    """Return a freshly initialised ``SegmentationModel`` with the given channel counts."""
    return SegmentationModel(in_channels=in_channels, out_channels=out_channels)

In [7]:
# Classifier preprocessing: fixed 224x224 resize, conversion to a tensor, then
# per-channel normalisation (the mean/std values coincide with the commonly
# used ImageNet statistics).
classification_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class SegmentationTransform:
    """Joint preprocessing for an (image, mask) pair.

    The image is resized to 224x224, converted to a tensor and normalised per
    channel. The mask is resized to 224x224 with nearest-neighbour resampling,
    given a leading channel axis and divided by 255.
    """

    def __init__(self):
        resize = transforms.Resize((224, 224))
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        self.image_transform = transforms.Compose([resize, to_tensor, normalize])

    def __call__(self, image, mask):
        """Return ``(image_tensor, mask_tensor)`` for one PIL image/mask pair."""
        image_tensor = self.image_transform(image)

        # Nearest-neighbour resampling avoids inventing intermediate mask values.
        resized_mask = mask.resize((224, 224), resample=Image.NEAREST)
        # float32 array with a new leading channel axis: (1, 224, 224).
        mask_array = np.array(resized_mask, dtype=np.float32)[np.newaxis, ...]
        # Divide by 255 (presumably the masks are 8-bit, giving a [0, 1] range — confirm).
        mask_tensor = torch.from_numpy(mask_array) / 255.0
        return image_tensor, mask_tensor

# Classification splits, one dataset per folder, all sharing the same preprocessing.
train_dataset, val_dataset, test_dataset = [
    ClassificationDataset(root_dir=split_root, transform=classification_transform)
    for split_root in (
        './Dataset_Arghadip/Train',
        './Dataset_Arghadip/Validation',
        './Dataset_Arghadip/Test',
    )
]

# Only the training split is shuffled; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader, test_loader = [
    DataLoader(split, batch_size=16, shuffle=False)
    for split in (val_dataset, test_dataset)
]

# Segmentation data: images and same-named masks live in sibling folders.
segmentation_image_dir = './Segmentaion_dataset/images'
segmentation_mask_dir = './Segmentaion_dataset/masks'

segmentation_transform = SegmentationTransform()
segmentation_dataset = SegmentationDataset(
    segmentation_image_dir,
    segmentation_mask_dir,
    transform=segmentation_transform,
)

segmentation_loader = DataLoader(segmentation_dataset, shuffle=True, batch_size=8)

In [8]:
# classification_model = models.resnet18(pretrained=True)
# num_ftrs = classification_model.fc.in_features
# classification_model.fc = nn.Linear(num_ftrs, 4)  # 4 classes
# classification_model = classification_model.to(device)

In [9]:
# segmentation_model = models.segmentation.fcn_resnet50(pretrained=False, num_classes=1)
# segmentation_model = segmentation_model.to(device)

In [10]:
# model1 = MICRO_SEM(3)
# model2 = MESO_SEM(3)
# model3 = MACRO_SEM(3)

In [11]:
# from torchinfo import summary
# summary_ = summary(model3, input_size = (32, 3, 224, 224))
# summary_

In [12]:
# segmentation_model = create_segmentation_model(3, 1)
# segmentation_model = segmentation_model.to(device)

In [13]:
# Build both networks and move their parameters to the selected device.
classification_model = create_detection_model(num_classes=4).to(device)
segmentation_model = create_segmentation_model(in_channels=3, out_channels=1).to(device)

In [14]:
from torchview import draw_graph
import graphviz as gv

# Select SVG as graphviz's Jupyter display format.
gv.set_jupyter_format('svg')

# Draw the classifier architecture for an (8, 3, 256, 256) input and save it
# under ./Architecture.
classification_model_graph = draw_graph(
    classification_model,
    input_size=(8, 3, 256, 256),
    graph_name='classification_model',
    expand_nested=True,
    save_graph=True,
    directory="./Architecture",
)


In [15]:
# Draw the segmentation network for an (8, 3, 256, 256) input and save it
# under ./Architecture.
segmentation_model_graph = draw_graph(
    segmentation_model,
    input_size=(8, 3, 256, 256),
    graph_name='segmentation_model',
    expand_nested=True,
    save_graph=True,
    directory="./Architecture",
)

In [16]:
# Render the DOT sources in ./Architecture (presumably the files written by the
# draw_graph calls with save_graph=True — confirm) to both SVG and PDF.
gv.render(engine = 'dot', format = 'svg', filepath = './Architecture/classification_model.gv')
gv.render(engine = 'dot', format = 'pdf', filepath = './Architecture/classification_model.gv')

# The final call is the cell's last expression, so its return value is what
# the notebook displays as this cell's output.
gv.render(engine = 'dot', format = 'svg', filepath = './Architecture/segmentation_model.gv')
gv.render(engine = 'dot', format = 'pdf', filepath = './Architecture/segmentation_model.gv')

'Architecture\\segmentation_model.gv.pdf'

In [None]:
from torchinfo import summary
# Per-layer summary of the classifier for a batch of 16 RGB 256x256 inputs.
# NOTE(review): classification_transform resizes to 224x224, so the reported
# activation sizes describe a different resolution than the one used for
# training — confirm that 256 is intended here.
summary_ = summary(classification_model, input_size = (16, 3, 256, 256))
# Bare expression so the notebook displays the summary object.
summary_

In [18]:
# criterion_segmentation = nn.BCEWithLogitsLoss()
# optimizer_segmentation = optim.Adam(segmentation_model.parameters(), lr=1e-4)

In [19]:
# Losses: cross-entropy over class logits for the classifier, and binary
# cross-entropy on raw logits for the segmentation masks.
criterion_classification = nn.CrossEntropyLoss()
criterion_segmentation = nn.BCEWithLogitsLoss()

# Optimisers: one Adam instance per network, both with a 1e-4 learning rate.
optimizer_classification = optim.Adam(params=classification_model.parameters(), lr=1e-4)
# Alternative kept for reference:
# optimizer_classification = optim.AdamW(classification_model.parameters(), lr = 0.001, weight_decay = 0.01, amsgrad = True)
optimizer_segmentation = optim.Adam(params=segmentation_model.parameters(), lr=1e-4)

In [20]:
# num_epochs = 10

# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0
#     running_corrects = 0
    
#     for inputs, labels in train_loader:
#         inputs = inputs.to(device)
#         labels = labels.to(device)
        
#         optimizer.zero_grad()
        
#         outputs = model(inputs)
#         _, preds = torch.max(outputs, 1)
#         loss = criterion(outputs, labels)
        
#         loss.backward()
#         optimizer.step()
        
#         running_loss += loss.item() * inputs.size(0)
#         running_corrects += torch.sum(preds == labels.data)
    
#     epoch_loss = running_loss / len(train_loader)
#     epoch_acc = running_corrects.double() / len(train_loader)
    
#     print('Epoch {}/{}: Loss: {:.4f} Acc: {:.4f}'.format(
#         epoch+1, num_epochs, epoch_loss, epoch_acc))


In [21]:
# mlflow.start_run()
# mlflow.log_param("batch_size", 16)
# mlflow.log_param("learning_rate", 0.001)
# mlflow.log_param("num_epochs", 100)

In [None]:
from tqdm.notebook import tqdm

# Train the classifier for a fixed number of epochs, validating after each one
# and checkpointing whenever validation accuracy matches or beats the best so far.
num_epochs = 100
# The progress bar advances once per training batch and once per validation batch.
total_len = len(train_loader) + len(val_loader)

# Per-epoch history. The "test_*" names hold metrics measured on val_loader.
train_losses = np.zeros(num_epochs)
test_losses = np.zeros(num_epochs)
train_acc = np.zeros(num_epochs)
test_acc = np.zeros(num_epochs)

best_test_acc = 0

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
checkpoint_dir = "./Models/Retrain"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    epoch_str = str(epoch + 1).rjust(len(str(num_epochs)), " ")
    with tqdm(total = total_len, desc = f"Epoch [ {epoch_str}/{num_epochs} ] : ") as pbar:

        # ---------------- training pass ----------------
        classification_model.train()
        n_correct = 0
        n_total = 0
        train_loss = []

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer_classification.zero_grad()

            outputs = classification_model(inputs)
            loss = criterion_classification(outputs, targets)
            train_loss.append(loss.item())

            loss.backward()
            optimizer_classification.step()

            # Predicted class = index of the largest logit.
            _, prediction = torch.max(outputs, 1)

            n_correct += (prediction == targets).sum().item()
            n_total += targets.shape[0]

            pbar.update(1)

        train_loss = np.mean(train_loss)
        train_acc_real = n_correct/n_total
        train_losses[epoch] = train_loss
        train_acc[epoch] = train_acc_real

        # ---------------- validation pass ----------------
        classification_model.eval()
        n_correct = 0
        n_total = 0
        test_loss = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = classification_model(inputs)

                loss = criterion_classification(outputs, targets)
                test_loss.append(loss.item())

                _, prediction = torch.max(outputs, 1)

                n_correct += (prediction == targets).sum().item()
                n_total += targets.shape[0]

                pbar.update(1)

        test_loss = np.mean(test_loss)
        test_acc_real = n_correct/n_total
        test_losses[epoch] = test_loss
        test_acc[epoch] = test_acc_real

        # Log the epoch-level validation metrics. The loss is the mean over all
        # validation batches (logging the last batch's loss alone is noisy and
        # not representative). The step keeps the previous numbering, i.e. the
        # global index of the epoch's last validation batch.
        # NOTE(review): no mlflow.start_run() is active in the visible cells, so
        # MLflow will implicitly open a run on the first log call — confirm.
        valid_step = (epoch + 1) * len(val_loader) - 1
        mlflow.log_metric("valid_loss", float(test_loss), step=valid_step)
        mlflow.log_metric("valid_accuracy", test_acc_real, step=valid_step)

        pbar.set_description(f"Epoch [ {epoch_str}/{num_epochs} ] ")
        pbar.set_postfix({'Train acc' : f'{train_acc_real:.3f}',
                          'Train loss' : f'{train_loss:.3f}',
                          'Test acc' : f'{test_acc_real:.3f}',
                          'Test loss': f'{test_loss:.3f}'
                         })

        # ">=" means ties also save, so a plateau writes a checkpoint every epoch.
        if test_acc_real >= best_test_acc:
            best_test_acc = test_acc_real
            # Accuracies as percentages truncated to one decimal place.
            test_str = f"{int(test_acc_real * 1000)/10}"
            train_str = f"{int(train_acc_real * 1000)/10}"
            torch.save(classification_model.state_dict(), f"{checkpoint_dir}/classification_model_train_{epoch}_{train_str}_test_{test_str}.pth")

In [None]:
# mlflow.pytorch.log_model(classification_model, "model")
# mlflow.end_run()

In [36]:
# Point the test split at the archive_2 testing folder. This rebinds the
# test_dataset / test_loader names for every cell executed afterwards.
test_dataset = ClassificationDataset(
    root_dir='./New_datasets/archive_2/Testing',
    transform=classification_transform,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

In [None]:
# Evaluate a saved classifier checkpoint and collect per-sample probabilities,
# predictions and ground-truth labels for the metric cells.
model_trained = create_detection_model(num_classes = 4)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU session still loads on a CPU-only machine.
checkpoint_state = torch.load(
    "./Models/classification_model_train_41_99.8_test_99.0.pth",
    map_location=device,
    weights_only=True,
)
model_trained.load_state_dict(checkpoint_state)
model_trained.to(device)
model_trained.eval()

val_running_corrects = 0

n_correct = 0
n_total = 0

n_correct_ = 0
n_total_ = 0

y_prob = []   # one softmax probability vector per sample
y_pred = []   # predicted class index per sample
y_true = []   # ground-truth class index per sample

torch.cuda.empty_cache()

# NOTE(review): this loop reads val_loader, yet the result is printed as
# "Test Accuracy" and the cell just above rebuilds test_loader — confirm which
# split is meant to be evaluated here.
with torch.no_grad():
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model_trained(inputs)
        probs_ = torch.softmax(outputs, dim = 1)
        _, prediction_ = torch.max(outputs, 1)

        # extend() adds one entry per sample (row) of the batch.
        y_prob.extend(probs_.cpu().numpy())
        y_pred.extend(prediction_.cpu().numpy())
        y_true.extend(targets.cpu().numpy())

        n_correct += (prediction_ == targets).sum().item()
        n_total += targets.shape[0]

test_acc_real = n_correct/n_total
print('Test Accuracy: {:.4f}'.format(test_acc_real))

In [25]:
# Convert the per-sample lists collected during evaluation into arrays.
# Probabilities stack into one row per sample; predictions and labels become
# column vectors.
y_prob_np = np.array(y_prob)
y_pred_np = np.array(y_pred)[:, np.newaxis]
y_true_np = np.array(y_true)[:, np.newaxis]

In [None]:
# Display the ground-truth label array as a quick sanity check.
y_true_np

In [27]:
from sklearn.preprocessing import label_binarize
# Binarize the labels over the four class indices (one column per class);
# the per-class and micro-average ROC curves below index into these columns.
y_true_bin = label_binarize(y_true_np, classes=[0, 1, 2, 3])

In [28]:
from sklearn.metrics import r2_score, roc_auc_score, precision_score, confusion_matrix, f1_score, accuracy_score, recall_score, roc_curve, auc

# Metrics computed from hard predictions.
cm = confusion_matrix(y_true_np, y_pred_np)
acc = accuracy_score(y_true_np, y_pred_np)
# NOTE(review): r2_score treats the class indices as numeric values — confirm
# this metric is meaningful for the report.
r2 = r2_score(y_true_np, y_pred_np)

# Support-weighted averages over the classes.
ps, rs, f1 = (
    metric(y_true_np, y_pred_np, average = 'weighted')
    for metric in (precision_score, recall_score, f1_score)
)

# One-vs-rest ROC AUC computed from the predicted class probabilities.
ras = roc_auc_score(y_true_np, y_prob_np, multi_class = 'ovr')

# Per-class ROC curves: column i of the binarized labels against the
# predicted probability of class i.
n_classes = 4
fpr, tpr, roc_auc = {}, {}, {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_prob_np[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

In [None]:
# Print the confusion matrix followed by one line per scalar metric.
print("Confusion Matrix : ")
print(cm, "\n")
for label, value in (
    ("R2 Score : ", r2),
    ("ROC AUC Score : ", ras),
    ("Precision Score : ", ps),
    ("F1 Score : ", f1),
    ("Accuracy : ", acc),
    ("Recall Score : ", rs),
):
    print(label, value)

In [30]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

In [None]:
# Micro-average ROC: flatten every (sample, class) pair into one binary problem.
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_prob_np.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Explicit Figure/Axes interface instead of the implicit pyplot state machine.
roc_fig, roc_ax = plt.subplots(figsize=(10, 8))

# One colour per class.
colors = ['#4379F2', '#604CC3', '#6EC207', '#FF6600']

for i in range(n_classes):
    roc_ax.plot(fpr[i], tpr[i], color=colors[i], lw=2,
                label=f'Class {i} (AUC = {roc_auc[i]:0.2f})')

# Micro-average curve.
roc_ax.plot(fpr["micro"], tpr["micro"], color='#1A3636', linestyle='dashed', linewidth=2,
            label=f'Micro-average ROC curve (AUC = {roc_auc["micro"]:0.2f})')

# Chance-level diagonal.
roc_ax.plot([0, 1], [0, 1], 'k--', lw=2)

# Axes limits, labels, title, legend and grid.
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate', fontsize=14)
roc_ax.set_ylabel('True Positive Rate', fontsize=14)
roc_ax.set_title('Multiclass ROC Curve', fontsize=18, fontweight='bold')
roc_ax.legend(loc="lower right", fontsize=12)
roc_ax.grid(True, linestyle='--', alpha=0.6)

plt.show()

# Macro-average one-vs-rest AUC across all classes.
roc_auc_macro = roc_auc_score(y_true_bin, y_prob_np, average="macro", multi_class="ovr")
print(f"Macro-average AUC: {roc_auc_macro:.2f}")

In [None]:
import plotly.graph_objects as go

# Micro-average: pool every (sample, class) pair into one ROC curve.
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_prob_np.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# One colour per class curve.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

# Build every trace up front; the list order is the legend/draw order.
class_traces = [
    go.Scatter(
        x=fpr[i],
        y=tpr[i],
        mode='lines',
        name=f'Class {i} (AUC = {roc_auc[i]:0.2f})',
        line=dict(color=colors[i], width=3),
    )
    for i in range(n_classes)
]

micro_trace = go.Scatter(
    x=fpr["micro"],
    y=tpr["micro"],
    mode='lines',
    name=f'Micro-average ROC (AUC = {roc_auc["micro"]:0.2f})',
    line=dict(color='deeppink', width=4, dash='dot'),
)

# Chance-level reference diagonal.
random_trace = go.Scatter(
    x=[0, 1],
    y=[0, 1],
    mode='lines',
    name='Random Classifier',
    line=dict(color='black', dash='dash'),
)

fig = go.Figure(data=class_traces + [micro_trace, random_trace])

# Titles, size, legend position and background.
fig.update_layout(
    title="Multiclass ROC Curve",
    xaxis_title="False Positive Rate",
    yaxis_title="True Positive Rate",
    font=dict(size=14),
    width=800,
    height=600,
    legend=dict(x=0.8, y=0.2),
    plot_bgcolor='white',
)

# Light grid on both axes.
grid_style = dict(showgrid=True, gridwidth=1, gridcolor='LightGray')
fig.update_xaxes(**grid_style)
fig.update_yaxes(**grid_style)

fig.show()

# Macro-average AUC: unweighted mean of the per-class one-vs-rest AUCs.
roc_auc_macro = roc_auc_score(y_true_bin, y_prob_np, average="macro", multi_class="ovr")
print(f"Macro-average AUC: {roc_auc_macro:.2f}")


In [None]:
# Measure top-1 accuracy of the classifier over test_loader.
# eval() switches layers such as dropout/batch-norm to inference behaviour and
# no_grad() skips autograd bookkeeping, since nothing is being optimised here.
classification_model.eval()

n_correct = 0
n_total = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = classification_model(inputs)

        # Index of the highest score along dim 1 is the predicted class.
        _, prediction = torch.max(outputs, 1)

        n_correct += (prediction == targets).sum().item()
        n_total += targets.shape[0]

# An empty loader would otherwise surface as an unexplained ZeroDivisionError.
if n_total == 0:
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")

test_acc_real = n_correct / n_total
# This loop runs over the test loader, so the figure is reported as test accuracy.
print('Test Accuracy: {:.4f}'.format(test_acc_real))

In [23]:
# Train the segmentation model and write a checkpoint every time the mean
# training loss of an epoch matches or beats the best value seen so far.

# Start at +inf so the first epoch always yields a checkpoint, whatever the
# scale of the loss; a fixed threshold would silently save nothing if the loss
# never dropped below it.
best_epoch_loss = float("inf")

num_epochs = 50

# torch.save does not create missing directories.
checkpoint_dir = "./Models"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # Right-align the 1-based epoch number so the progress-bar labels line up.
    epoch_str = str(epoch + 1).rjust(len(str(num_epochs)), " ")
    with tqdm(total = len(segmentation_loader), desc = f"Epoch [ {epoch_str}/{num_epochs} ] : ") as pbar:
        segmentation_model.train()
        running_loss = 0.0
        for inputs, masks in segmentation_loader:
            inputs = inputs.to(device)
            masks = masks.to(device)

            optimizer_segmentation.zero_grad()
            outputs = segmentation_model(inputs)
            # Masks are cast to float before being handed to the criterion.
            loss = criterion_segmentation(outputs, masks.float())
            loss.backward()
            optimizer_segmentation.step()
            # Weight by batch size so that dividing by the dataset length below
            # gives a per-sample mean (assuming the criterion returns a batch mean).
            running_loss += loss.item() * inputs.size(0)
            pbar.update(1)

        epoch_loss = running_loss / len(segmentation_dataset)
        pbar.set_description(f"Epoch [ {epoch_str}/{num_epochs} ] ")
        pbar.set_postfix({'Train loss' : f'{epoch_loss:.3f}'})

        if epoch_loss <= best_epoch_loss:
            best_epoch_loss = epoch_loss
            train_str = f"{best_epoch_loss:.3f}"
            # The file name carries the 0-based epoch index and the rounded loss.
            torch.save(
                segmentation_model.state_dict(),
                f"{checkpoint_dir}/segmentation_model_train_{epoch}_{train_str}.pth",
            )

KeyboardInterrupt: 

In [None]:
# Rebuild the 4-class classifier and restore its trained weights from disk.
classification_model = create_detection_model(num_classes = 4)

# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when only the CPU is available.
# weights_only=True restricts unpickling to tensors and plain containers.
classification_state = torch.load(
    "./Models/classification_model_train_92_100.0_test_99.2.pth",
    map_location=device,
    weights_only=True,
)
classification_model.load_state_dict(classification_state)
classification_model.to(device)

In [None]:
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Single-image demo: classify one test scan, segment it, and show the original
# image, the binary mask, the masked overlay and the predicted class side by side.

classification_model.eval()
segmentation_model.eval()

# Both pipelines resize to 224x224, convert to a tensor and normalise each
# channel with the listed mean/std values.
classification_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

segmentation_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

test_image_path = './Dataset_Arghadip/Test/glioma/glioma (1).jpg'
# convert() returns an independent in-memory copy, so the file handle can be
# released as soon as the context manager exits.
with Image.open(test_image_path) as image_file:
    original_image = image_file.convert('RGB')

# unsqueeze(0) adds the batch dimension -> shape [1, 3, 224, 224].
input_image_classification = classification_transform(original_image)
input_image_classification = input_image_classification.unsqueeze(0).to(device)

input_image_segmentation = segmentation_transform(original_image)
input_image_segmentation = input_image_segmentation.unsqueeze(0).to(device)

# Disable autograd only for these forward passes. Turning it off globally would
# leave the whole kernel without gradients and break any training cell that is
# run afterwards.
with torch.no_grad():
    classification_output = classification_model(input_image_classification)
    # Expected output shape is [1, 1, 224, 224] - confirm against the model definition.
    segmentation_output = segmentation_model(input_image_segmentation)

_, predicted_class_idx = torch.max(classification_output, 1)
predicted_class_idx = predicted_class_idx.item()

# Must be the exact inverse of ClassificationDataset.class_to_idx
# (glioma=0, meningioma=1, notumor=2, pituitary=3); any other order mislabels
# the prediction.
idx_to_class = {0: 'glioma', 1: 'meningioma', 2: 'notumor', 3: 'pituitary'}
predicted_class_name = idx_to_class[predicted_class_idx]

# Drop the batch dimension, squash the raw outputs into (0, 1) and binarise.
segmentation_output = segmentation_output.squeeze(0).cpu()
segmentation_output = torch.sigmoid(segmentation_output)
threshold = 0.5
mask_array = (segmentation_output > threshold).float().squeeze().numpy()

# Turn the 0/1 array into an 8-bit image and scale it back to the source
# resolution; nearest-neighbour resampling keeps the mask strictly binary.
mask = Image.fromarray((mask_array * 255).astype(np.uint8))
mask = mask.resize(original_image.size, resample=Image.NEAREST)

# Using the mask as the alpha channel leaves only the predicted region opaque.
overlay = original_image.copy()
overlay.putalpha(mask)

fig, axs = plt.subplots(1, 4, figsize=(20, 5))

axs[0].imshow(original_image)
axs[0].set_title('Original Image')
axs[0].axis('off')

axs[1].imshow(mask, cmap='gray')
axs[1].set_title('Segmentation Mask')
axs[1].axis('off')

axs[2].imshow(overlay)
axs[2].set_title('Overlay')
axs[2].axis('off')

axs[3].text(0.5, 0.5, f'Predicted Class:\n{predicted_class_name}', fontsize=18, ha='center')
axs[3].set_axis_off()

plt.show()