## CC-Net Reimplementation


- Reimplementation starting from Section 4 of the paper

In [None]:
import kagglehub
import os
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google.colab import drive
import random


In [None]:
# Mount Google Drive into the Colab filesystem and list the top level of MyDrive.
drive.mount('/content/drive')
print(os.listdir('/content/drive/MyDrive'))
# Optional: change the working directory to the project folder (path is specific to one Drive layout).
# os.chdir('/content/drive/MyDrive/Class Materials/Fall 25/STOR566/STOR566Project') # Please update this line with the correct path

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
['College Application', 'Stats', 'CompSci', 'Class Materials', 'Colab Notebooks', 'Miscellaneous.gdoc', '剧本杀.gdoc']


In [None]:
# Fix every random number generator used in this notebook so runs are repeatable.
seed = 2025
# One statement per library: the calls are made for their side effect only, so
# there is no reason to pack their return values into a throwaway tuple (which
# the notebook would otherwise echo as cell output).
torch.manual_seed(seed)   # PyTorch RNG (weight init, random_split, shuffling)
np.random.seed(seed)      # NumPy global RNG
random.seed(seed)         # Python stdlib RNG

(<torch._C.Generator at 0x7c322b788f50>, None, None)

## Dataset reformat and dataloader

In [None]:
def handpd_loaders(path, shuffle_test=False, *, batch_size=32,
                   train_fraction=0.8, seed=None):
    """Build train/test DataLoaders from an image-folder dataset.

    Images under ``path`` are loaded with ``ImageFolder``, resized to
    227 x 227, converted to tensors and normalised, then randomly split into
    a training subset and a test subset.

    Args:
        path: Root directory of the dataset, passed to ``ImageFolder``.
        shuffle_test: Whether the test loader shuffles its samples.
        batch_size: Batch size used by both loaders (default 32).
        train_fraction: Fraction of samples assigned to the training subset;
            must lie strictly between 0 and 1 (default 0.8).
        seed: Optional seed for the split. ``None`` (default) draws the split
            from torch's global RNG; an integer makes the split reproducible
            regardless of the global RNG state.

    Returns:
        Tuple ``(train_loader, test_loader, full_dataset)``.

    Raises:
        ValueError: If ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(
            f"train_fraction must be in (0, 1), got {train_fraction!r}"
        )

    # reformat
    # NOTE(review): the conventional ImageNet std is [0.229, 0.224, 0.225];
    # the values below differ, but the display code elsewhere in this file
    # un-normalises with the same numbers, so they are kept unchanged here.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.225, 0.225, 0.225])
    # paper resized to 227 * 227
    transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        normalize,
    ])

    full_dataset = ImageFolder(path, transform=transform)

    # train/test split (80-20 with the default train_fraction)
    train_size = int(train_fraction * len(full_dataset))
    test_size = len(full_dataset) - train_size
    if seed is None:
        # Same generator random_split uses when none is passed.
        split_generator = torch.default_generator
    else:
        split_generator = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = random_split(
        full_dataset, [train_size, test_size], generator=split_generator
    )

    # loaders
    # Pinned host memory only helps host-to-GPU copies, so request it only
    # when CUDA is present (avoids a warning on CPU-only machines).
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=shuffle_test, pin_memory=pin_memory)

    # properties
    print(f"Total samples: {len(full_dataset)}")
    print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
    print(f"Classes: {full_dataset.classes}")

    return train_loader, test_loader, full_dataset


In [None]:
# Root folder of the image dataset on Google Drive (path is specific to one Drive layout).
base_path = '/content/drive/MyDrive/Class Materials/Fall 25/STOR566/STOR566Project/Parkinsons'
# Build the loaders with the default settings (test loader not shuffled).
train_loader, test_loader, full_dataset = handpd_loaders(base_path)

Total samples: 3264
Train samples: 2611, Test samples: 653
Classes: ['Healthy', 'Parkinson']


In [None]:
from collections import Counter

In [None]:
# Count samples per class index over the whole dataset (before the train/test split).
counts = Counter(full_dataset.targets)
print("Class index mapping:", full_dataset.class_to_idx)
print("Total distribution:", counts)

Class index mapping: {'Healthy': 0, 'Parkinson': 1}
Total distribution: Counter({0: 1632, 1: 1632})


In [None]:
# Pull a single batch from the training loader to sanity-check tensor shapes.
images, labels = next(iter(train_loader))
# NOTE(review): this re-scans the folder only to list class names; full_dataset.classes
# appears to hold the same list -- confirm nothing else uses `dataset` before removing.
dataset = ImageFolder(base_path)
# labels -- 0(healthy) or 1(Parkinson)
print(images.shape, labels.shape)
print(dataset.classes)

torch.Size([32, 3, 227, 227]) torch.Size([32])
['Healthy', 'Parkinson']


## Visualize one example

In [None]:
# Pick the last image of the batch (index 31 for a full batch of 32).
# The index is derived from the batch length so a shorter batch cannot raise IndexError.
sample_idx = len(images) - 1
img = images[sample_idx]
lbl = labels[sample_idx]

# Undo normalization for display (same mean/std values the loader's Normalize uses)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.225, 0.225, 0.225])
# Reorder axes from channels-first (C, H, W) to channels-last (H, W, C) for imshow
img = img.numpy().transpose((1, 2, 0))
img = std * img + mean
# Clamp to [0, 1] so imshow treats the values as valid float colours
img = np.clip(img, 0, 1)
# Access the original dataset classes (the loader's dataset wraps the full dataset)
classes = train_loader.dataset.dataset.classes

# Plot the image
plt.imshow(img)
plt.title(f"Label: {classes[lbl]}")
plt.axis('off')
plt.show()

### Preprocessing is not possible here because the dataset is combined: the preprocessing applies only to spiral images and is not applicable to the properties of the combined dataset

## Model Implementation

In [None]:
class handpd_ccnet(nn.Module):
    """Convolutional classifier made of two three-convolution modules and a linear head.

    The spatial sizes noted in the comments assume 227 x 227 inputs; the final
    linear layer takes 128 * 1 * 1 features, which is what that input size
    produces after the second pooling stage.

    Args:
        num_classes: Number of output logits (default 2).
    """

    def __init__(self, num_classes = 2):
        super().__init__()

        # Module A: 227 -> 56 -> 28 -> (pool) 9
        module_a = [
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),    # 227 -> 56
            nn.ReLU(),
            nn.Conv2d(48, 128, kernel_size=5, stride=2, padding=2),   # 56 -> 28
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),  # 28 -> 28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=3),                    # 28 -> 9
        ]

        # Module B: 9 -> 9 -> 9 -> 5 -> (pool) 1
        module_b = [
            nn.Conv2d(128, 192, kernel_size=3, stride=1, padding=1),  # 9 -> 9
            nn.ReLU(),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),  # 9 -> 9
            nn.ReLU(),
            nn.Conv2d(192, 128, kernel_size=3, stride=2, padding=1),  # 9 -> 5
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=3),                    # 5 -> 1
        ]

        # Head: flatten the 128 x 1 x 1 map and project onto the class logits
        head = [
            nn.Flatten(),
            nn.Linear(128 * 1 * 1, num_classes),
        ]

        # Single Sequential so parameter names stay classifier.<index>.*
        self.classifier = nn.Sequential(*module_a, *module_b, *head)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        logits = self.classifier(x)
        return logits



## Training

In [None]:
from tqdm import tqdm

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Instantiate the network with its default number of classes and move it to the selected device.
cc_net_model = handpd_ccnet().to(device)
# Print the module structure.
print(cc_net_model)

In [None]:
# Training configuration and loop.
# Evaluation metrics (accuracy, precision, recall, F1-score) are computed in later cells.
learning_rate = 0.0001
# originally 2000
num_of_epoch = 50
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(cc_net_model.parameters(), lr = learning_rate)

# Mean training loss of each epoch, in order; read by the loss plot.
total_loss = []
for epoch in tqdm(range(num_of_epoch)):
    cc_net_model.train()
    # Sum of per-batch losses within this epoch
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Clear gradients left over from the previous step before the new backward pass
        optimizer.zero_grad()
        outputs = cc_net_model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average over the number of batches (every batch weighs the same, whatever its size)
    avg_loss = running_loss / len(train_loader)
    total_loss.append(avg_loss)

    print(f"Epoch [{epoch+1}/{num_of_epoch}], Loss: {avg_loss:.4f}")



In [None]:
# Loss Tracking Plot
# The x-axis is derived from the losses actually recorded, so the plot still
# works when training stopped early or num_of_epoch was changed afterwards
# (plt.plot raises on x/y length mismatch).
y = total_loss
x = np.arange(len(y))
plt.plot(x, y)
plt.title("Average Epoch Loss")
plt.xlabel("number of epochs")
plt.ylabel("loss")
plt.show()

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

In [None]:
# Testing Accuracy
# Switch the network to evaluation mode before scoring the held-out split.
cc_net_model.eval()

# Running tallies plus per-sample records that the metrics cell reads afterwards
correct, total = 0, 0
total_predictions, total_labels, probs = [], [], []

# No gradients are needed for evaluation
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = cc_net_model(images)
        # Index of the largest logit per row is the predicted class
        predicted = torch.max(outputs, dim=1).indices
        # Probability assigned to class 1, kept for the AUC computation
        softmax_probs = F.softmax(outputs, dim=1)[:, 1]

        batch_hits = (predicted == labels).sum().item()
        total += labels.size(0)
        correct += batch_hits

        total_predictions.extend(predicted.cpu().numpy())
        total_labels.extend(labels.cpu().numpy())
        probs.extend(softmax_probs.cpu().numpy())

accuracy = correct / total
print(f"Test Accuracy: {accuracy*100:.2f}%")


In [None]:
# Performance measurement on the collected test predictions
y_true, y_pred = np.array(total_labels), np.array(total_predictions)

# average='binary' scores a single positive class (sklearn's pos_label defaults to 1)
precision = precision_score(y_true, y_pred, average='binary')
recall = recall_score(y_true, y_pred, average='binary')
f1 = f1_score(y_true, y_pred, average='binary')

# AUC uses the class-1 softmax probabilities rather than the hard predictions
auc = roc_auc_score(y_true, probs)

print(
    "Evaluation Metrics:",
    f"Precision: {precision:.4f}",
    f"Recall:    {recall:.4f}",
    f"F1-score:  {f1:.4f}",
    f"AUC:       {auc:.4f}",
    sep="\n",
)


In [None]:
def ccnet_visuals(image, title = None):
    """Draw one normalised image tensor on the current matplotlib axes.

    Args:
        image: Tensor that supports ``.numpy()``, laid out channels-first;
            its axes are reordered to channels-last for display.
        title: Optional title; skipped when falsy.
    """
    channel_std = np.array([0.225, 0.225, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    # Channels-first -> channels-last, then invert the normalisation
    pixels = image.numpy().transpose((1, 2, 0))
    pixels = pixels * channel_std + channel_mean
    # Clamp to the [0, 1] range before handing the array to imshow
    pixels = np.clip(pixels, 0, 1)

    plt.imshow(pixels)
    if title:
        plt.title(title)
    plt.axis('off')

# take first batch as examples
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# Inference only: make sure the model is in eval mode and skip building the
# autograd graph for this forward pass.
cc_net_model.eval()
with torch.no_grad():
    outputs = cc_net_model(images)
_, preds = torch.max(outputs, 1)

# Show up to 8 samples in a 2 x 4 grid; a batch smaller than 8 would otherwise
# raise IndexError.
num_shown = min(8, images.size(0))
plt.figure(figsize=(12, 6))
for i in range(num_shown):
    plt.subplot(2, 4, i+1)
    ccnet_visuals(images[i].cpu(), title=f"Pred: {preds[i].item()}, True: {labels[i].item()}")
plt.tight_layout()
plt.show()