In [62]:
from skimage.segmentation import slic
from skimage.color import rgb2lab
from skimage.graph import rag_mean_color 
from skimage.measure import regionprops
from sklearn.preprocessing import StandardScaler
from skimage import graph
import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch.utils.data import Dataset


X_train_np = np.load("/home/sajedhamdan/Desktop/skin_cancer/images_train_256x192.npy")
y_train_np = np.load("/home/sajedhamdan/Desktop/skin_cancer/train_labels.npy")

X_val_np = np.load("/home/sajedhamdan/Desktop/skin_cancer/images_val_256x192.npy")
y_val_np = np.load("/home/sajedhamdan/Desktop/skin_cancer/val_labels.npy")

X_test_np = np.load("/home/sajedhamdan/Desktop/skin_cancer/images_test_256x192.npy")
y_test_np = np.load("/home/sajedhamdan/Desktop/skin_cancer/test_labels.npy")


In [41]:
# a function to create graph data from an image
def create_graph_from_image(image_tensor, label_tensor):
    
    image_np = image_tensor.permute(1, 2, 0).numpy() 
    label = label_tensor.item() 

    segments = slic(image_np, n_segments=100, compactness=10, sigma=1, channel_axis=-1, enforce_connectivity=True)

    num_segments = np.max(segments) + 1
    image_lab = rgb2lab(image_np)

    node_features = np.zeros((num_segments, 6))

    for i in range(num_segments):
        mask = (segments == i)
        lab_pixels_in_segment = image_lab[mask]
    
        if lab_pixels_in_segment.size > 0:
            node_features[i, :3] = lab_pixels_in_segment.mean(axis=0) 
            node_features[i, 3:] = lab_pixels_in_segment.std(axis=0) 
        else:
            node_features[i, :] = 0.0 

    rag = graph.rag_mean_color(image_np, segments)
    
    edge_index = []
    for node1, node2 in rag.edges():
        edge_index.append([node1, node2])
        edge_index.append([node2, node1])

    if not edge_index:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    data = Data(x=torch.tensor(node_features, dtype=torch.float32), 
                edge_index=edge_index, 
                y=torch.tensor([label], dtype=torch.long))
    return data


class CustomGraphDataset(Dataset):
    def __init__(self, image_data, label_data, scaler=None, train=True):
        self.image_data = image_data  
        self.label_data = label_data
        self.graphs = []
        self.scaler = scaler
        self.train = train

        print(f"Generating graphs for {'training' if train else 'validation'} dataset:")
        for i in range(len(self.image_data)):
           
            img_tensor = torch.tensor(self.image_data[i], dtype=torch.float32).permute(2, 0, 1)
            lbl_tensor = torch.tensor(self.label_data[i], dtype=torch.long)
            graph_data = create_graph_from_image(img_tensor, lbl_tensor)

            num_nodes = graph_data.x.size(0) 

            train_mask = torch.zeros(num_nodes, dtype=torch.bool)
            val_mask = torch.zeros(num_nodes, dtype=torch.bool)
            test_mask = torch.zeros(num_nodes, dtype=torch.bool)

            train_mask[:int(0.7 * num_nodes)] = 1 
            val_mask[int(0.7 * num_nodes):int(0.85 * num_nodes)] = 1
            test_mask[int(0.85 * num_nodes):] = 1 

            graph_data.train_mask = train_mask
            graph_data.val_mask = val_mask
            graph_data.test_mask = test_mask

            self.graphs.append(graph_data)

        print(f"Finished generating graphs for {'training' if train else 'validation'} dataset.")

        if self.train:  
            self.fit_scaler()
        
        self.transform_data()

    def fit_scaler(self):
        
        print("Fitting StandardScaler on training node features...")
        all_node_features = []
        for graph_data in self.graphs:
            if graph_data.x is not None and graph_data.x.numel() > 0:  
                all_node_features.append(graph_data.x.numpy()) 
        
        if all_node_features:
            combined_features = np.vstack(all_node_features)
            self.scaler = StandardScaler()
            self.scaler.fit(combined_features)
            print("StandardScaler fitted successfully.")
        else:
            print("No node features found to fit scaler.")

    def transform_data(self):
       
        if self.scaler is not None:
            print(f"Transforming node features for {'training' if self.train else 'validation'} dataset:")
            for i, graph_data in enumerate(self.graphs):
                if graph_data.x is not None and graph_data.x.numel() > 0:  
                    transformed_features = self.scaler.transform(graph_data.x.numpy())
                    self.graphs[i].x = torch.tensor(transformed_features, dtype=torch.float32)
                else:
                    print(f"Warning: Graph {i} has no nodes (empty graph.x).")
            print(f"Finished transforming node features for {'training' if self.train else 'validation'} dataset.")
        elif not self.train:
            print("Warning: No scaler provided for validation dataset.")

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]


In [42]:
# custom datasets
train_dataset = CustomGraphDataset(X_train_np, y_train_np, train=True)
val_dataset = CustomGraphDataset(X_val_np, y_val_np, scaler=train_dataset.scaler, train=False)

# dataLoaders for graph data 
train_loader = GeometricDataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = GeometricDataLoader(val_dataset, batch_size=32)

print(f"Number of training graphs: {len(train_dataset)}")
print(f"Number of validation graphs: {len(val_dataset)}")

Generating graphs for training dataset...
Finished generating graphs for training dataset.
Fitting StandardScaler on training node features...
StandardScaler fitted successfully.
Transforming node features for training dataset...
Finished transforming node features for training dataset.
Generating graphs for validation dataset...
Finished generating graphs for validation dataset.
Transforming node features for validation dataset...
Finished transforming node features for validation dataset.

--- Data Loading Complete ---
Number of training graphs: 8111
Number of validation graphs: 902
Sample training batch (from DataLoader): DataBatch(x=[3415, 6], edge_index=[2, 17868], y=[32], train_mask=[3415], val_mask=[3415], test_mask=[3415], batch=[3415], ptr=[33])
Shape of node features in sample batch: torch.Size([3415, 6])
Number of graphs in sample batch: 32


In [63]:
test_dataset = CustomGraphDataset(X_test_np, y_test_np, scaler=train_dataset.scaler, train=False)
test_loader = GeometricDataLoader(test_dataset, batch_size=32)

Generating graphs for validation dataset...
Finished generating graphs for validation dataset.
Transforming node features for validation dataset...
Finished transforming node features for validation dataset.


In [145]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        if batch is not None:
            x = global_mean_pool(x, batch)

        out = self.linear(x)
        return out


In [146]:
num_features = train_dataset.graphs[0].x.shape[1]  
num_classes = len(set(train_dataset.label_data))   


model = GCN(
    in_channels=num_features,  
    hidden_channels=args.hidden_channels,
    out_channels=num_classes,  
).to(device)

optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),  
    dict(params=model.conv2.parameters(), weight_decay=0)],   
    lr=0.0001)


scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)


In [147]:
def train():
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        
        loss = F.cross_entropy(out, data.y)
        loss.backward()  
        optimizer.step() 
        
        total_loss += loss.item()

    return total_loss / len(train_loader)

@torch.no_grad()
def test():
    model.eval() 
    
    all_preds = [[], [], []] 
    all_labels = [[], [], []]  

    for loader_idx, loader in enumerate([train_loader, val_loader, test_loader]):
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            pred = out.argmax(dim=1)  
            labels = data.y  

            all_preds[loader_idx].extend(pred.cpu().tolist())
            all_labels[loader_idx].extend(labels.cpu().tolist())

    # accuracy, precision and recall for all datasets
    results = []
    for preds, labels in zip(all_preds, all_labels):
        acc = (sum(p == l for p, l in zip(preds, labels)) / len(labels))
        prec = precision_score(labels, preds, average='macro', zero_division=0) 
        rec = recall

In [149]:
from sklearn.metrics import accuracy_score, precision_score, recall_score

best_val_acc = test_acc = 0
times = []
epochs = 50 
best_epoch = 0

for epoch in range(epochs):
    model.train() 
    total_train_loss = 0
    all_train_labels = []
    all_train_preds = []

    # training phase
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
        
        pred = out.argmax(dim=1)
        all_train_labels.append(data.y.cpu().numpy())
        all_train_preds.append(pred.cpu().numpy())

    all_train_labels = np.concatenate(all_train_labels)
    all_train_preds = np.concatenate(all_train_preds)
    train_accuracy = accuracy_score(all_train_labels, all_train_preds)
    train_precision = precision_score(all_train_labels, all_train_preds, average='weighted', zero_division=1)
    train_recall = recall_score(all_train_labels, all_train_preds, average='weighted', zero_division=1)

    # validation phase
    model.eval() 
    total_val_loss = 0
    all_val_labels = []
    all_val_preds = []
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            val_loss = F.cross_entropy(out, data.y)
            total_val_loss += val_loss.item()

            pred = out.argmax(dim=1)
            all_val_labels.append(data.y.cpu().numpy())
            all_val_preds.append(pred.cpu().numpy())

    val_loss = total_val_loss / len(val_loader)
    
    all_val_labels = np.concatenate(all_val_labels)
    all_val_preds = np.concatenate(all_val_preds)
    val_accuracy = accuracy_score(all_val_labels, all_val_preds)
    val_precision = precision_score(all_val_labels, all_val_preds, average='weighted', zero_division=1)
    val_recall = recall_score(all_val_labels, all_val_preds, average='weighted', zero_division=1)

    # updating learning rate based on validation loss
    scheduler.step(val_loss)

    # test phase
    model.eval()  
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        if batch is not None:
            x = global_mean_pool(x, batch)

        out = self.linear(x)
        return out

        class ImprovedGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(ImprovedGCN, self).__init__()
    all_test_labels = []
    all_test_preds = []
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)

            pred = out.argmax(dim=1)
            all_test_labels.append(data.y.cpu().numpy())
            all_test_preds.append(pred.cpu().numpy())

    all_test_labels = np.concatenate(all_test_labels)
    all_test_preds = np.concatenate(all_test_preds)
    test_accuracy = accuracy_score(all_test_labels, all_test_preds)
    test_precision = precision_score(all_test_labels, all_test_preds, average='weighted', zero_division=1)
    test_recall = recall_score(all_test_labels, all_test_preds, average='weighted', zero_division=1)

    # performance metrics
    print(f'Epoch {epoch+1}/{epochs}')
    print(f'Train Loss: {total_train_loss / len(train_loader):.4f}, Train Accuracy: {train_accuracy:.4f}, Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}')
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}')
    print(f'Test Accuracy: {test_accuracy:.4f}, Test Precision: {test_precision:.4f}, Test Recall: {test_recall:.4f}')

    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        test_acc = test_accuracy
        best_epoch = epoch

print(f'Final test accuracy: {test_acc:.4f}')


Epoch 1/50
Train Loss: 0.8175, Train Accuracy: 0.7029, Train Precision: 0.6493, Train Recall: 0.7029
Validation Loss: 0.8403, Validation Accuracy: 0.6785, Validation Precision: 0.6056, Validation Recall: 0.6785
Test Accuracy: 0.7046, Test Precision: 0.6716, Test Recall: 0.7046
Epoch 2/50
Train Loss: 0.8159, Train Accuracy: 0.7021, Train Precision: 0.6492, Train Recall: 0.7021
Validation Loss: 0.8394, Validation Accuracy: 0.6785, Validation Precision: 0.6057, Validation Recall: 0.6785
Test Accuracy: 0.7006, Test Precision: 0.6354, Test Recall: 0.7006
Epoch 3/50
Train Loss: 0.8161, Train Accuracy: 0.7030, Train Precision: 0.6487, Train Recall: 0.7030
Validation Loss: 0.8388, Validation Accuracy: 0.6796, Validation Precision: 0.6081, Validation Recall: 0.6796
Test Accuracy: 0.7056, Test Precision: 0.6452, Test Recall: 0.7056
Epoch 4/50
Train Loss: 0.8147, Train Accuracy: 0.7035, Train Precision: 0.6505, Train Recall: 0.7035
Validation Loss: 0.8381, Validation Accuracy: 0.6796, Validation 

In [None]:
torch.save(model, 'GCN.pth')