In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import numpy as np

In [2]:
# Prefer the first visible CUDA GPU; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
class ClassificationDataset(Dataset):
    """Image-folder dataset for 4-way brain-tumour classification.

    ``root_dir`` is expected to contain one sub-directory per class, named after
    the keys of ``class_to_idx``. Sub-directories with any other name are ignored,
    and every regular file inside a recognised class directory is treated as an
    image.

    Args:
        root_dir: Directory holding the per-class sub-directories.
        transform: Optional callable applied to the RGB PIL image in ``__getitem__``.

    Each sample is ``(image, label)``, where ``label`` is the integer index from
    ``class_to_idx``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_to_idx = {
            'glioma': 0,
            'meningioma': 1,
            'pituitary': 2,
            'notumor': 3
        }
        # sorted(): os.listdir order is arbitrary and filesystem-dependent, which
        # made the sample order (e.g. val/test iteration order) non-reproducible.
        for class_name in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, class_name)
            if not (os.path.isdir(class_dir) and class_name in self.class_to_idx):
                continue
            label = self.class_to_idx[class_name]
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Skip nested directories: they would only fail later in Image.open.
                if os.path.isfile(img_path):
                    self.images.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        # The context manager closes the file once the converted RGB copy exists.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
    
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for tumour segmentation.

    Every regular file in ``image_dir`` is paired with the file of the same name
    in ``mask_dir``. Images are loaded as RGB and masks as single-channel ('L').

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing masks whose file names match the images.
        transform: Optional callable taking ``(image, mask)`` and returning the
            transformed pair.

    NOTE(review): a missing mask file only surfaces as an error in ``__getitem__``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted for a reproducible order; non-files (sub-directories) are skipped
        # because they cannot be opened as images.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)
        # Context managers close both files once the converted copies exist.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')  # Convert mask to grayscale
        if self.transform:
            image, mask = self.transform(image, mask)
        return image, mask

In [4]:
class MICRO_SEM(nn.Module):
    """Micro-scale branch: three single-conv stages.

    Each stage is conv(3x3, 'same') -> ReLU -> BatchNorm -> max-pool(3x3).
    The pool stride defaults to its kernel size, so each stage floors the
    spatial size by a factor of 3. Channels go ``num_in -> 32 -> 48 -> 96``,
    so a 224x224 input yields a 96 x 8 x 8 map.
    """

    def __init__(self, num_in):
        super(MICRO_SEM, self).__init__()
        # Attribute names and creation order are significant: they define the
        # state_dict keys and the order in which weights are initialised.
        self.conv_1 = nn.Conv2d(num_in, 32, kernel_size=3, padding='same')
        self.relu = nn.ReLU()
        self.batch_norm_1 = nn.BatchNorm2d(32)
        self.pool_1 = nn.MaxPool2d(kernel_size=(3, 3))

        self.conv_2 = nn.Conv2d(32, 48, kernel_size=3, padding='same')
        self.batch_norm_2 = nn.BatchNorm2d(48)
        self.pool_2 = nn.MaxPool2d(kernel_size=(3, 3))

        self.conv_3 = nn.Conv2d(48, 96, kernel_size=3, padding='same')
        self.batch_norm_3 = nn.BatchNorm2d(96)
        self.pool_3 = nn.MaxPool2d(kernel_size=(3, 3))

    def forward(self, x):
        stages = (
            (self.conv_1, self.batch_norm_1, self.pool_1),
            (self.conv_2, self.batch_norm_2, self.pool_2),
            (self.conv_3, self.batch_norm_3, self.pool_3),
        )
        for conv, norm, pool in stages:
            # ReLU is applied before BatchNorm, as in the rest of this model.
            x = pool(norm(self.relu(conv(x))))
        return x
    
class MESO_SEM(nn.Module):
    """Meso-scale branch: three double-conv stages with one residual shortcut.

    Each stage is conv(3x3) -> ReLU -> conv(3x3) -> ReLU -> BatchNorm ->
    max-pool(3, stride 3). Channels go ``num_in -> 48 -> 64 -> 96``.

    The stage-1 output is also projected by ``resize`` (3x3 conv, stride 3, no
    padding -> ReLU -> BatchNorm) to 64 channels at the same spatial size as the
    stage-2 output. It is added to that output before stage 3.
    """

    def __init__(self, num_in):
        super(MESO_SEM, self).__init__()

        # Attribute names and creation order define the state_dict keys and the
        # weight-initialisation order; keep them unchanged.
        self.conv_1_a = nn.Conv2d(num_in, 32, kernel_size=3, padding='same')
        self.conv_1_b = nn.Conv2d(32, 48, kernel_size=3, padding='same')
        self.batch_norm_1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU()
        self.pool_1 = nn.MaxPool2d(kernel_size=3, stride=3)

        self.conv_2_a = nn.Conv2d(48, 64, kernel_size=3, padding='same')
        self.conv_2_b = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.batch_norm_2 = nn.BatchNorm2d(64)
        self.pool_2 = nn.MaxPool2d(kernel_size=3, stride=3)

        # Shortcut: a stride-3 unpadded conv floors H/3 exactly like pool_2 does.
        self.resize = nn.Sequential(
            nn.Conv2d(48, 64, kernel_size=3, stride=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )

        self.conv_3_a = nn.Conv2d(64, 96, kernel_size=3, padding='same')
        self.conv_3_b = nn.Conv2d(96, 96, kernel_size=3, padding='same')
        self.batch_norm_3 = nn.BatchNorm2d(96)
        self.pool_3 = nn.MaxPool2d(kernel_size=3, stride=3)

    def forward(self, x):
        def stage(t, conv_a, conv_b, norm, pool):
            t = self.relu(conv_a(t))
            t = self.relu(conv_b(t))
            return pool(norm(t))

        x = stage(x, self.conv_1_a, self.conv_1_b, self.batch_norm_1, self.pool_1)
        shortcut = self.resize(x)
        x = stage(x, self.conv_2_a, self.conv_2_b, self.batch_norm_2, self.pool_2) + shortcut
        return stage(x, self.conv_3_a, self.conv_3_b, self.batch_norm_3, self.pool_3)

class MACRO_SEM(nn.Module):
    """Macro-scale branch: four double-conv (5x5) stages with two shortcuts.

    Stages 1-3 are conv(5x5) -> ReLU -> conv(5x5) -> ReLU -> BatchNorm ->
    max-pool(2, stride 2). Stage 4 ends in max-pool(7x7, stride 3).
    Channels go ``num_in -> 48 -> 64 -> 96 -> 96``.

    Shortcuts, each a 2x2 / stride-2 conv -> ReLU -> BatchNorm:
      * ``resize_1`` projects the stage-1 output; it is added to the stage-2 output.
      * ``resize_2`` projects the stage-2 output (taken before that addition); it
        is added to the stage-3 output.
    For a 224x224 input the result is 96 x 8 x 8.
    """

    def __init__(self, num_in):
        super(MACRO_SEM, self).__init__()

        # Attribute names and creation order define the state_dict keys and the
        # weight-initialisation order; keep them unchanged.
        self.conv_1_a = nn.Conv2d(num_in, 32, kernel_size=5, padding='same')
        self.conv_1_b = nn.Conv2d(32, 48, kernel_size=5, padding='same')
        self.batch_norm_1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU()
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.resize_1 = nn.Sequential(
            nn.Conv2d(48, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )

        self.conv_2_a = nn.Conv2d(48, 64, kernel_size=5, padding='same')
        self.conv_2_b = nn.Conv2d(64, 64, kernel_size=5, padding='same')
        self.batch_norm_2 = nn.BatchNorm2d(64)
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.resize_2 = nn.Sequential(
            nn.Conv2d(64, 96, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(96),
        )

        self.conv_3_a = nn.Conv2d(64, 96, kernel_size=5, padding='same')
        self.conv_3_b = nn.Conv2d(96, 96, kernel_size=5, padding='same')
        self.batch_norm_3 = nn.BatchNorm2d(96)
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_4_a = nn.Conv2d(96, 128, kernel_size=5, padding='same')
        self.conv_4_b = nn.Conv2d(128, 96, kernel_size=5, padding='same')
        self.batch_norm_4 = nn.BatchNorm2d(96)
        self.pool_4 = nn.MaxPool2d(kernel_size=(7, 7), stride=(3, 3))

    def forward(self, x):
        def stage(t, conv_a, conv_b, norm, pool):
            t = self.relu(conv_a(t))
            t = self.relu(conv_b(t))
            return pool(norm(t))

        stage_1 = stage(x, self.conv_1_a, self.conv_1_b, self.batch_norm_1, self.pool_1)
        skip_1 = self.resize_1(stage_1)

        stage_2 = stage(stage_1, self.conv_2_a, self.conv_2_b, self.batch_norm_2, self.pool_2)
        # Projected from the plain stage-2 output, i.e. before skip_1 is added.
        skip_2 = self.resize_2(stage_2)

        stage_3 = stage(stage_2 + skip_1, self.conv_3_a, self.conv_3_b, self.batch_norm_3, self.pool_3)
        return stage(stage_3 + skip_2, self.conv_4_a, self.conv_4_b, self.batch_norm_4, self.pool_4)
    
class BrinTumor(nn.Module):
    """Three-branch CNN classifier (micro / macro / meso feature branches).

    Each branch is built by calling the supplied module class with ``num_in``
    and wrapping the instance in an ``nn.Sequential``. The branch outputs are
    concatenated along dim 1 (micro, meso, macro), flattened, and passed through
    a fully connected head with dropout:
    dropout -> fc_1 -> ReLU -> dropout -> fc_2 -> ReLU -> out_layer.

    Args:
        micro, macro, meso: Module classes taking ``num_in`` as their only argument.
        num_in: Number of input image channels.
        num_out: Number of output logits.

    NOTE(review): ``fc_1`` expects 288 * 8 * 8 flattened features (three branches
    of 96 x 8 x 8). The branches in this notebook produce that for 224x224 input;
    other sizes will fail at ``fc_1``.
    """

    def __init__(self, micro, macro, meso, num_in, num_out):
        super(BrinTumor, self).__init__()
        # Creation order matters for state_dict ordering and weight initialisation.
        self.micro_layer = self._make_layer(micro, num_in)
        self.macro_layer = self._make_layer(macro, num_in)
        self.meso_layer = self._make_layer(meso, num_in)

        self.flatten = nn.Flatten(1)

        self.fc_1 = nn.Linear(288 * 8 * 8, 4096)
        self.drop_out = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.fc_2 = nn.Linear(4096, 512)
        self.out_layer = nn.Linear(512, num_out)

    def forward(self, x):
        # Branches are evaluated macro -> meso -> micro but concatenated micro, meso, macro.
        macro_feats = self.macro_layer(x)
        meso_feats = self.meso_layer(x)
        micro_feats = self.micro_layer(x)
        fused = torch.cat((micro_feats, meso_feats, macro_feats), dim=1)

        hidden = self.drop_out(self.flatten(fused))
        hidden = self.drop_out(self.relu(self.fc_1(hidden)))
        hidden = self.relu(self.fc_2(hidden))
        return self.out_layer(hidden)

    def _make_layer(self, Module, num_in):
        """Instantiate ``Module(num_in)`` and wrap it in an ``nn.Sequential``.

        Also records ``num_in`` on ``self.in_channels``.
        """
        self.in_channels = num_in
        return nn.Sequential(Module(self.in_channels))
    
def BrinTumor_Main(num_in, num_out):
    """Build a BrinTumor classifier wired with the MICRO_SEM / MACRO_SEM / MESO_SEM branches.

    Args:
        num_in: Number of input image channels.
        num_out: Number of output classes (logits).
    """
    return BrinTumor(
        micro=MICRO_SEM,
        macro=MACRO_SEM,
        meso=MESO_SEM,
        num_in=num_in,
        num_out=num_out,
    )

In [5]:
# Resize to 224x224 (the size BrinTumor's 288*8*8 fully connected input is built for),
# convert to a tensor, then normalise with the standard ImageNet channel mean/std.
classification_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

class SegmentationTransform:
    """Joint transform for an (image, mask) pair.

    * Image: resized to 224x224, converted to a tensor, and normalised with the
      standard ImageNet mean/std.
    * Mask: resized to 224x224 with nearest-neighbour sampling (so label values
      are not interpolated), converted to a float32 tensor of shape
      ``[1, 224, 224]``, and divided by 255.
    """

    def __init__(self):
        self.image_transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

    def __call__(self, image, mask):
        image_tensor = self.image_transform(image)

        resized_mask = mask.resize((224, 224), resample=Image.NEAREST)
        # Add a leading channel axis so the mask matches a 1-channel model output.
        mask_array = np.array(resized_mask, dtype=np.float32)[np.newaxis, ...]
        # Scale mask to [0, 1] range if needed
        mask_tensor = torch.from_numpy(mask_array) / 255.0

        return image_tensor, mask_tensor

# Classification splits: each root holds one sub-directory per class.
train_dataset = ClassificationDataset(root_dir='./Dataset_Arghadip/Train', transform=classification_transform)
val_dataset = ClassificationDataset(root_dir='./Dataset_Arghadip/Validation', transform=classification_transform)
test_dataset = ClassificationDataset(root_dir='./Dataset_Arghadip/Test', transform=classification_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
# Must wrap test_dataset: building it from val_dataset made the final "test"
# evaluation silently re-score the validation split.
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Segmentation data: SegmentationDataset pairs images and masks by identical file name.
# NOTE(review): 'Segmentaion' looks like a typo, but it must match the on-disk
# directory name, so it is left unchanged.
segmentation_image_dir = './Segmentaion_dataset/images'
segmentation_mask_dir = './Segmentaion_dataset/masks'

segmentation_transform = SegmentationTransform()
segmentation_dataset = SegmentationDataset(image_dir=segmentation_image_dir,
                                           mask_dir=segmentation_mask_dir,
                                           transform=segmentation_transform)

segmentation_loader = DataLoader(segmentation_dataset, batch_size=8, shuffle=True)

In [6]:
# classification_model = models.resnet18(pretrained=True)
# num_ftrs = classification_model.fc.in_features
# classification_model.fc = nn.Linear(num_ftrs, 4)  # 4 classes
# classification_model = classification_model.to(device)

In [7]:
# segmentation_model = models.segmentation.fcn_resnet50(pretrained=False, num_classes=1)
# segmentation_model = segmentation_model.to(device)

In [8]:
# model1 = MICRO_SEM(3)
# model2 = MESO_SEM(3)
# model3 = MACRO_SEM(3)

In [9]:
# from torchinfo import summary
# summary_ = summary(model3, input_size = (32, 3, 224, 224))
# summary_

In [None]:
# Three-branch classifier: 3 input channels (RGB), 4 tumour classes.
classification_model = BrinTumor_Main(num_in=3, num_out=4).to(device)

# FCN-ResNet50 with a single output channel: per-pixel logits for a binary mask
# (trained below with BCEWithLogitsLoss).
segmentation_model = models.segmentation.fcn_resnet50(pretrained=False, num_classes=1).to(device)

In [None]:
from torchinfo import summary

# (batch, channels, height, width) of the dummy input used to trace layer shapes.
summary_input_size = (32, 3, 224, 224)
summary_ = summary(classification_model, input_size=summary_input_size)
summary_

In [12]:
# Classification: multi-class cross-entropy on raw logits.
criterion_classification = nn.CrossEntropyLoss()
# Segmentation: single-channel logits vs. masks scaled to [0, 1].
criterion_segmentation = nn.BCEWithLogitsLoss()

# Both models use Adam with the same learning rate.
optimizer_classification = optim.Adam(classification_model.parameters(), lr=1e-4)
optimizer_segmentation = optim.Adam(segmentation_model.parameters(), lr=1e-4)

In [13]:
# num_epochs = 10

# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0
#     running_corrects = 0
    
#     for inputs, labels in train_loader:
#         inputs = inputs.to(device)
#         labels = labels.to(device)
        
#         optimizer.zero_grad()
        
#         outputs = model(inputs)
#         _, preds = torch.max(outputs, 1)
#         loss = criterion(outputs, labels)
        
#         loss.backward()
#         optimizer.step()
        
#         running_loss += loss.item() * inputs.size(0)
#         running_corrects += torch.sum(preds == labels.data)
    
#     epoch_loss = running_loss / len(train_loader)
#     epoch_acc = running_corrects.double() / len(train_loader)
    
#     print('Epoch {}/{}: Loss: {:.4f} Acc: {:.4f}'.format(
#         epoch+1, num_epochs, epoch_loss, epoch_acc))


In [None]:
from tqdm.notebook import tqdm
import copy

num_epochs = 100
total_len = len(train_loader) + len(val_loader)  # progress-bar steps per epoch (train + val batches)

# Per-epoch history. NOTE: the `test_*` arrays hold *validation* metrics
# (computed on val_loader), not test-split metrics.
train_losses = np.zeros(num_epochs)
test_losses = np.zeros(num_epochs)
train_acc = np.zeros(num_epochs)
test_acc = np.zeros(num_epochs)

# Best validation accuracy so far and a snapshot of the weights that achieved it.
# deepcopy is required because state_dict() returns references to the live tensors.
best_test_acc = 0
best_model_state = copy.deepcopy(classification_model.state_dict())

for epoch in range(num_epochs):
    epoch_str = str(epoch + 1).rjust(len(str(num_epochs)), " ")
    with tqdm(total = total_len, desc = f"Epoch [ {epoch_str}/{num_epochs} ] : ") as pbar:

        # ---- training pass ----
        classification_model.train()
        n_correct = 0
        n_total = 0
        batch_losses = []

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer_classification.zero_grad()

            outputs = classification_model(inputs)
            loss = criterion_classification(outputs, targets)
            batch_losses.append(loss.item())

            loss.backward()
            optimizer_classification.step()

            _, prediction = torch.max(outputs, 1)

            n_correct += (prediction == targets).sum().item()
            n_total += targets.shape[0]

            pbar.update(1)

        train_loss = np.mean(batch_losses)
        train_acc_real = n_correct/n_total
        train_losses[epoch] = train_loss
        train_acc[epoch] = train_acc_real

        # ---- validation pass ----
        classification_model.eval()
        n_correct = 0
        n_total = 0
        batch_losses = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = classification_model(inputs)

                loss = criterion_classification(outputs, targets)
                batch_losses.append(loss.item())

                _, prediction = torch.max(outputs, 1)

                n_correct += (prediction == targets).sum().item()
                n_total += targets.shape[0]

                pbar.update(1)

        test_loss = np.mean(batch_losses)
        test_acc_real = n_correct/n_total
        test_losses[epoch] = test_loss
        test_acc[epoch] = test_acc_real

        # Keep the weights from the epoch with the highest validation accuracy.
        if test_acc_real > best_test_acc:
            best_test_acc = test_acc_real
            best_model_state = copy.deepcopy(classification_model.state_dict())

        pbar.set_description(f"Epoch [ {epoch_str}/{num_epochs} ] ")
        pbar.set_postfix({'Train acc' : f'{train_acc_real:.3f}',
                          'Train loss' : f'{train_loss:.3f}',
                          'Test acc' : f'{test_acc_real:.3f}',
                          'Test loss': f'{test_loss:.3f}'
                         })

# Restore the best-on-validation weights so later cells evaluate that checkpoint
# rather than whatever the final epoch produced.
classification_model.load_state_dict(best_model_state)
print('Best validation accuracy: {:.4f}'.format(best_test_acc))

In [None]:
# Accuracy of the classifier on the batches from test_loader.
classification_model.eval()

n_correct = 0
n_total = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = classification_model(inputs)
        _, prediction = torch.max(outputs, 1)

        n_correct += (prediction == targets).sum().item()
        n_total += targets.shape[0]

test_acc_real = n_correct/n_total
# Labelled "Test": this cell iterates test_loader, not the validation loader.
print('Test Accuracy: {:.4f}'.format(test_acc_real))

In [None]:
import tqdm as tqdm_sim

# NOTE: reuses `num_epochs` defined in the classification training cell.
for epoch in range(num_epochs):
    segmentation_model.train()
    running_loss = 0.0
    batches = tqdm_sim.tqdm(enumerate(segmentation_loader), total=len(segmentation_loader))
    for i, (inputs, masks) in batches:
        inputs, masks = inputs.to(device), masks.to(device)

        optimizer_segmentation.zero_grad()
        # torchvision segmentation models return a dict; 'out' holds the main head.
        outputs = segmentation_model(inputs)['out']
        loss = criterion_segmentation(outputs, masks.float())
        loss.backward()
        optimizer_segmentation.step()

        # Weight each batch loss by its size so dividing by the dataset length
        # below yields a per-sample average.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(segmentation_dataset)
    print('Segmentation Epoch {}/{} Loss: {:.4f}'.format(epoch+1, num_epochs, epoch_loss))

In [None]:
import matplotlib.pyplot as plt

classification_model.eval()
segmentation_model.eval()

# Shared preprocessing for both models. It matches the resize + normalisation
# used during training. It is bound to a new name so this cell does not
# overwrite `classification_transform` or `segmentation_transform` (the paired
# SegmentationTransform used by segmentation_dataset).
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

idx_to_class = {0: 'glioma', 1: 'meningioma', 2: 'pituitary', 3: 'notumor'}
threshold = 0.5  # probability cut-off for a pixel to be counted in the mask

test_image_path = './Dataset_Arghadip/Test/glioma/glioma (1).jpg'
with Image.open(test_image_path) as img:
    original_image = img.convert('RGB')

input_image = inference_transform(original_image).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

# Scoped no_grad() instead of a global torch.set_grad_enabled(False): the global
# switch stayed off after this cell, so re-running any training cell afterwards
# would fail in loss.backward().
with torch.no_grad():
    classification_output = classification_model(input_image)
    segmentation_output = segmentation_model(input_image)['out']  # Output shape: [1, 1, 224, 224]

_, predicted_class_idx = torch.max(classification_output, 1)
predicted_class_idx = predicted_class_idx.item()
predicted_class_name = idx_to_class[predicted_class_idx]

# Logits -> probabilities -> binary mask, then back to the original image size.
segmentation_output = segmentation_output.squeeze(0).cpu()  # Shape: [1, 224, 224]
segmentation_output = torch.sigmoid(segmentation_output)  # Values between 0 and 1
mask = (segmentation_output > threshold).float()
mask = mask.squeeze().numpy()  # Shape: [224, 224]

mask = Image.fromarray((mask * 255).astype(np.uint8))
mask = mask.resize(original_image.size, resample=Image.NEAREST)

# putalpha uses the mask as the alpha channel: pixels outside the predicted
# region become fully transparent.
overlay = original_image.copy()
overlay.putalpha(mask)

fig, axs = plt.subplots(1, 4, figsize=(20, 5))

axs[0].imshow(original_image)
axs[0].set_title('Original Image')
axs[0].axis('off')

axs[1].imshow(mask, cmap='gray')
axs[1].set_title('Segmentation Mask')
axs[1].axis('off')

axs[2].imshow(overlay)
axs[2].set_title('Overlay')
axs[2].axis('off')

axs[3].text(0.5, 0.5, f'Predicted Class:\n{predicted_class_name}', fontsize=18, ha='center')
axs[3].set_axis_off()

plt.show()
