<a href="https://colab.research.google.com/github/JacGuo/Thesis_COMP7882/blob/main/COMP7882_taskrelevant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# import pytorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Log the installed PyTorch build (version + CUDA tag) so the outputs below can be tied to it.
print(f"{torch.__version__}")

2.4.0+cu121


In [2]:
import torch.nn.functional as F
from tqdm import tqdm
import torchvision.models as models
from torchvision.models import ResNet18_Weights

In [3]:
# Report whether CUDA is usable and, if so, which GPU is attached.
# torch.cuda.get_device_name() raises when no CUDA device exists, so it is only
# queried when one is available; this keeps "Run All" working on CPU-only kernels.
cuda_available = torch.cuda.is_available()
print('Cuda Available : {}'.format(cuda_available))
if cuda_available:
    print('GPU - {0}'.format(torch.cuda.get_device_name()))

Cuda Available : True
GPU - Tesla T4


In [4]:
# Use the GPU when present, otherwise fall back to the CPU instead of failing later on `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

**Load CIFAR-10 dataset**

In [5]:
# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device:', device)

# Preprocessing: convert to tensors, then normalize each channel with CIFAR-10 statistics.
# NOTE(review): the backbone is ImageNet-pretrained (ResNet18_Weights.IMAGENET1K_V1); ImageNet
# statistics may suit it better. Also, (0.2023, 0.1994, 0.2010) is a widely copied std triple;
# the measured CIFAR-10 train-set std is reportedly closer to (0.247, 0.243, 0.261) — confirm
# before changing. Values are left as-is so results stay comparable with earlier runs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # CIFAR-10 specific normalization
])

# Load the CIFAR-10 dataset (downloads into ./data on first run)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split the official training set 80/20 into training and validation subsets.
# A fixed-seed generator makes the split reproducible across kernel restarts; without it,
# every "Run All" would train and validate on a different random partition.
SPLIT_SEED = 42
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Collapse CIFAR-10 class indices onto the binary animal-vs-vehicle task.
def map_labels_to_binary(target):
    """Return the binary label for the CIFAR-10 class index ``target``.

    Classes 2-7 ('bird', 'cat', 'deer', 'dog', 'frog', 'horse') map to 0 (Animal);
    every other value maps to 1 (Vehicle).
    """
    return 0 if target in (2, 3, 4, 5, 6, 7) else 1

# Relabel every split for the binary task. Iterating a split applies its transform, and the
# resulting (tensor, label) pairs are kept in memory, so later epochs reuse the cached tensors.
train_dataset_binary, val_dataset_binary, test_dataset_binary = (
    [(image, map_labels_to_binary(label)) for image, label in split]
    for split in (train_dataset, val_dataset, test_dataset)
)

# Data loaders: only the training loaders shuffle; validation/test order stays fixed.
BATCH_SIZE = 64

train_loader_binary = DataLoader(train_dataset_binary, batch_size=BATCH_SIZE, shuffle=True)
val_loader_binary = DataLoader(val_dataset_binary, batch_size=BATCH_SIZE, shuffle=False)
test_loader_binary = DataLoader(test_dataset_binary, batch_size=BATCH_SIZE, shuffle=False)

train_loader_multiclass = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader_multiclass = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader_multiclass = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Using device: cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 30987531.60it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


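Define the **Feature Extractor** (ImageNet-pretrained ResNet-18 backbone + 256-d linear projection) and the KL-style feature regularizer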

In [6]:
class FeatureExtractor(nn.Module):
    """ImageNet-pretrained ResNet-18 trunk followed by a linear projection to 256 features.

    The trunk keeps every ResNet-18 child module except the final classification layer
    (``fc``), so its output is the globally pooled descriptor. The attribute names
    ``resnet`` and ``fc`` are kept because they determine the ``state_dict`` keys.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Drop only the last child (the 1000-way ImageNet classifier).
        trunk_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*trunk_layers)
        # Project the pooled descriptor (512 channels for ResNet-18) down to 256 dims.
        self.fc = nn.Linear(backbone.fc.in_features, 256)

    def forward(self, x):
        """Map an image batch ``x`` (presumably (batch, 3, H, W)) to a (batch, 256) feature matrix."""
        pooled = self.resnet(x)  # expected (batch, 512, 1, 1) after the trunk's pooling layer
        return self.fc(pooled.flatten(start_dim=1))


def kl_divergence(features, eps=1e-8):
    """Regularizer pulling each feature dimension's batch distribution toward N(0, 1).

    Each dimension is summarised by its batch mean ``mu`` and unbiased batch variance
    ``var``; the penalty is ``sum(mu**2 + var - log(var + eps) - 1)`` over dimensions.
    This is twice the closed-form KL(N(mu, var) || N(0, 1)); the conventional factor 1/2
    is omitted, which only rescales the caller's weight (``lambda_val``).

    Args:
        features: 2-D tensor, presumably (batch, feature_dim); statistics are taken over dim 0.
        eps: added to the variance inside the log so a constant (zero-variance) feature
            dimension gives a large finite penalty instead of ``inf`` / NaN gradients.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: if the batch has fewer than 2 rows (unbiased variance is undefined and
            would silently turn the loss into NaN).
    """
    if features.shape[0] < 2:
        raise ValueError(
            f"kl_divergence needs a batch of at least 2 samples, got {features.shape[0]}"
        )
    mean = torch.mean(features, dim=0)
    # Use the variance directly: d sqrt(v)/dv is infinite at v == 0, so std**2 would give NaN gradients.
    var = torch.var(features, dim=0)
    return torch.sum(mean**2 + var - torch.log(var + eps) - 1)


Define the target model: binary classification (animal vs. vehicle)

In [7]:
class TargetModelBinary(nn.Module):
    """MLP head for the binary task: 256-d features -> 128 -> 64 -> one probability.

    The forward pass ends in a sigmoid, so it returns probabilities in (0, 1) rather than logits.
    That matches the ``nn.BCELoss`` criterion and the 0.5 threshold used elsewhere in this notebook.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)  # single output unit for the binary decision

    def forward(self, x):
        # Two ReLU hidden layers, then a sigmoid-squashed output unit.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return torch.sigmoid(self.fc3(x))

Define the different-task model: multiclass image classification (10 CIFAR-10 classes)

In [8]:
class DifferentTaskModel(nn.Module):
    """MLP head for the 10-class task: 256-d features -> 128 -> 64 -> 10 raw logits.

    No softmax is applied: ``nn.CrossEntropyLoss`` expects unnormalized logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # one logit per CIFAR-10 class

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

**Training**

Create the model instances

In [9]:

# Build the shared backbone and the two task heads, moving each onto `device` as it is created.
# Classes are constructed in the same order as before (backbone, binary head, multiclass head),
# so random initialisation is consumed in the same sequence.
feature_extractor, target_model_binary, different_task_model = (
    net_cls().to(device)
    for net_cls in (FeatureExtractor, TargetModelBinary, DifferentTaskModel)
)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 150MB/s]


Define the loss functions and the optimizer

In [10]:
# Binary head emits sigmoid probabilities, so plain BCE (not BCEWithLogits) is the matching loss.
criterion_binary = nn.BCELoss(reduction='mean')

# 10-way head emits raw logits, which is what CrossEntropyLoss expects.
criterion_multiclass = nn.CrossEntropyLoss(reduction='mean')

# Overall objective for the binary task: BCE plus a weighted feature-distribution penalty.
def total_loss_binary(features, task_output, target, lambda_val=0.01):
    """Return ``BCE(task_output, target) + lambda_val * kl_divergence(features)``.

    Args:
        features: backbone features for the batch, passed to ``kl_divergence``.
        task_output: sigmoid probabilities from the binary head (one column per sample).
        target: 0/1 labels; cast to float and reshaped to a column to match ``task_output``.
        lambda_val: weight of the regularization term.
    """
    penalty = kl_divergence(features)
    labels = target.float().view(-1, 1)
    return criterion_binary(task_output, labels) + lambda_val * penalty



# One Adam optimizer over the backbone AND the binary head, so binary training also fine-tunes
# the feature extractor. Parameter order: backbone first, then head.
optimizer_binary = optim.Adam(
    [*feature_extractor.parameters(), *target_model_binary.parameters()],
    lr=1e-4,
)



Training loop

In [11]:
# Number of epochs to train the model (also reused by the multiclass cells below)
num_epochs = 10


def _train_binary_one_epoch(epoch):
    """One optimisation pass over ``train_loader_binary``; returns the mean per-batch loss."""
    feature_extractor.train()
    target_model_binary.train()

    loss_sum = 0.0
    for images, labels in tqdm(train_loader_binary, desc=f"Epoch {epoch+1} - Binary Task"):
        images, labels = images.to(device), labels.to(device)

        optimizer_binary.zero_grad()
        feats = feature_extractor(images)
        batch_loss = total_loss_binary(feats, target_model_binary(feats), labels)
        batch_loss.backward()
        optimizer_binary.step()

        loss_sum += batch_loss.item()
    return loss_sum / len(train_loader_binary)


def _binary_accuracy(loader):
    """Percent of samples in ``loader`` whose thresholded (p > 0.5) prediction equals the label."""
    feature_extractor.eval()
    target_model_binary.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            probs = target_model_binary(feature_extractor(images))
            preds = (probs > 0.5).float().view(-1)
            seen += labels.size(0)
            hits += (preds == labels).sum().item()
    return hits * 100 / seen


# Jointly train backbone + binary head, then measure validation accuracy.
for epoch in range(num_epochs):
    mean_loss = _train_binary_one_epoch(epoch)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

val_accuracy_binary = _binary_accuracy(val_loader_binary)
print(f'Validation Accuracy (Binary Task): {val_accuracy_binary:.2f}%')



Epoch 1 - Binary Task: 100%|██████████| 625/625 [00:17<00:00, 36.74it/s]


Epoch [1/10], Loss: 0.2506


Epoch 2 - Binary Task: 100%|██████████| 625/625 [00:16<00:00, 38.04it/s]


Epoch [2/10], Loss: 0.0898


Epoch 3 - Binary Task: 100%|██████████| 625/625 [00:15<00:00, 41.42it/s]


Epoch [3/10], Loss: 0.0525


Epoch 4 - Binary Task: 100%|██████████| 625/625 [00:15<00:00, 40.52it/s]


Epoch [4/10], Loss: 0.0348


Epoch 5 - Binary Task: 100%|██████████| 625/625 [00:14<00:00, 42.03it/s]


Epoch [5/10], Loss: 0.0257


Epoch 6 - Binary Task: 100%|██████████| 625/625 [00:14<00:00, 42.53it/s]


Epoch [6/10], Loss: 0.0204


Epoch 7 - Binary Task: 100%|██████████| 625/625 [00:14<00:00, 43.03it/s]


Epoch [7/10], Loss: 0.0154


Epoch 8 - Binary Task: 100%|██████████| 625/625 [00:14<00:00, 43.13it/s]


Epoch [8/10], Loss: 0.0162


Epoch 9 - Binary Task: 100%|██████████| 625/625 [00:14<00:00, 42.51it/s]


Epoch [9/10], Loss: 0.0144


Epoch 10 - Binary Task: 100%|██████████| 625/625 [00:14<00:00, 42.36it/s]


Epoch [10/10], Loss: 0.0117
Validation Accuracy (Binary Task): 96.99%


Test the feature extractor on a different task: train a 10-class classifier on its frozen features

In [13]:
# Probe how well the binary-trained backbone transfers to a different task: keep the feature
# extractor frozen and fit only the 10-way head (different_task_model) on its features.
feature_extractor.eval()      # frozen: never updated in this cell
different_task_model.train()  # the only network that learns here

optimizer_multiclass = optim.Adam(different_task_model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    batch_losses = []
    for images, labels in tqdm(train_loader_multiclass, desc=f"Epoch {epoch+1} - Multiclass Task"):
        images, labels = images.to(device), labels.to(device)
        optimizer_multiclass.zero_grad()

        # no_grad stops autograd from tracing through the backbone; only the head gets gradients.
        with torch.no_grad():
            frozen_feats = feature_extractor(images)
        logits = different_task_model(frozen_feats)

        batch_loss = criterion_multiclass(logits, labels)
        batch_loss.backward()
        optimizer_multiclass.step()
        batch_losses.append(batch_loss.item())

    mean_loss = sum(batch_losses) / len(train_loader_multiclass)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

# Top-1 accuracy (percent) of the frozen-backbone probe on the CIFAR-10 test split.
different_task_model.eval()
n_correct = n_seen = 0
with torch.no_grad():
    for images, labels in test_loader_multiclass:
        images, labels = images.to(device), labels.to(device)
        predicted = different_task_model(feature_extractor(images)).max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

test_accuracy_multiclass = n_correct * 100 / n_seen
print(f'Test Accuracy (Multiclass Task): {test_accuracy_multiclass:.2f}%')

Epoch 1 - Multiclass Task: 100%|██████████| 625/625 [00:14<00:00, 42.63it/s]


Epoch [1/10], Loss: 1.6661


Epoch 2 - Multiclass Task: 100%|██████████| 625/625 [00:14<00:00, 42.82it/s]


Epoch [2/10], Loss: 1.4302


Epoch 3 - Multiclass Task: 100%|██████████| 625/625 [00:14<00:00, 42.69it/s]


Epoch [3/10], Loss: 1.3073


Epoch 4 - Multiclass Task: 100%|██████████| 625/625 [00:15<00:00, 41.14it/s]


Epoch [4/10], Loss: 1.2266


Epoch 5 - Multiclass Task: 100%|██████████| 625/625 [00:16<00:00, 38.43it/s]


Epoch [5/10], Loss: 1.1742


Epoch 6 - Multiclass Task: 100%|██████████| 625/625 [00:14<00:00, 42.30it/s]


Epoch [6/10], Loss: 1.1368


Epoch 7 - Multiclass Task: 100%|██████████| 625/625 [00:14<00:00, 42.26it/s]


Epoch [7/10], Loss: 1.1097


Epoch 8 - Multiclass Task: 100%|██████████| 625/625 [00:15<00:00, 41.04it/s]


Epoch [8/10], Loss: 1.0913


Epoch 9 - Multiclass Task: 100%|██████████| 625/625 [00:14<00:00, 41.83it/s]


Epoch [9/10], Loss: 1.0763


Epoch 10 - Multiclass Task: 100%|██████████| 625/625 [00:15<00:00, 41.58it/s]


Epoch [10/10], Loss: 1.0640
Test Accuracy (Multiclass Task): 56.13%


In [14]:
# Test the binary classification task on the held-out test set.
# Set eval mode explicitly instead of relying on an earlier cell having done it. In train mode
# the ResNet's BatchNorm layers would normalise with per-batch statistics and overwrite their
# running statistics with test data (that update happens even under torch.no_grad()).
feature_extractor.eval()
target_model_binary.eval()

correct = 0
total = 0
with torch.no_grad():
    for data, targets in test_loader_binary:
        data, targets = data.to(device), targets.to(device)
        features = feature_extractor(data)
        outputs = target_model_binary(features)
        predicted = (outputs > 0.5).float()  # Apply threshold for binary classification
        total += targets.size(0)
        correct += (predicted.view(-1) == targets).sum().item()

test_accuracy_binary = correct * 100 / total
print(f'Test Accuracy (Binary Task): {test_accuracy_binary:.2f}%')

Test Accuracy (Binary Task): 96.59%


In [17]:
class EnhancedDifferentTaskModel(nn.Module):
    """Deeper 10-class head: 256-d features -> 512 -> 256 -> 128 -> 10 raw logits.

    Compared with ``DifferentTaskModel`` it widens the first layer and adds one hidden layer.
    It returns logits (no softmax) for use with ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256, 512)  # input width matches the feature extractor's 256-d output
        self.fc2 = nn.Linear(512, 256)  # the extra hidden layer
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)   # one logit per class

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden_layer(x))
        return self.fc4(x)

In [18]:
# Train a deeper 10-way head on top of the binary-trained backbone, keeping the backbone frozen.
enhanced_multiclass_model = EnhancedDifferentTaskModel().to(device)

# Only the head's parameters are optimized; the feature extractor is never updated here.
optimizer_multiclass = optim.Adam(enhanced_multiclass_model.parameters(), lr=1e-4)

# Redefined so this cell runs standalone.
criterion_multiclass = nn.CrossEntropyLoss()
num_epochs = 10

# The backbone is only a fixed feature source in this cell. Put it in eval mode explicitly so its
# BatchNorm layers use stored statistics regardless of which cells ran before, and run it under
# no_grad. Without no_grad, every step built an autograd graph through ResNet-18 and accumulated
# unused gradients into the extractor's .grad buffers (never zeroed or consumed by this optimizer),
# wasting time and GPU memory without changing the result.
feature_extractor.eval()


def _frozen_features(images):
    """Backbone features for ``images`` computed without gradient tracking."""
    with torch.no_grad():
        return feature_extractor(images)


def _multiclass_accuracy(model, loader):
    """Top-1 accuracy (percent) of ``model`` on frozen-backbone features over ``loader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in loader:
            data, targets = data.to(device), targets.to(device)
            _, predicted = torch.max(model(feature_extractor(data)), 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    return correct * 100 / total


# Training loop for multiclass classification (head only)
for epoch in range(num_epochs):
    enhanced_multiclass_model.train()

    running_loss = 0.0

    for data, targets in tqdm(train_loader_multiclass, desc=f"Epoch {epoch+1} - Multiclass Task"):
        data, targets = data.to(device), targets.to(device)

        optimizer_multiclass.zero_grad()

        outputs = enhanced_multiclass_model(_frozen_features(data))
        loss = criterion_multiclass(outputs, targets)

        loss.backward()
        optimizer_multiclass.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader_multiclass):.4f}")

# Validate, then test, the enhanced multiclass head
val_accuracy_multiclass = _multiclass_accuracy(enhanced_multiclass_model, val_loader_multiclass)
print(f'Validation Accuracy (Multiclass Task): {val_accuracy_multiclass:.2f}%')

test_accuracy_multiclass = _multiclass_accuracy(enhanced_multiclass_model, test_loader_multiclass)
print(f'Test Accuracy (Multiclass Task): {test_accuracy_multiclass:.2f}%')

Epoch 1 - Multiclass Task: 100%|██████████| 625/625 [00:21<00:00, 28.72it/s]


Epoch [1/10], Loss: 1.5625


Epoch 2 - Multiclass Task: 100%|██████████| 625/625 [00:20<00:00, 30.17it/s]


Epoch [2/10], Loss: 1.3115


Epoch 3 - Multiclass Task: 100%|██████████| 625/625 [00:21<00:00, 29.28it/s]


Epoch [3/10], Loss: 1.1788


Epoch 4 - Multiclass Task: 100%|██████████| 625/625 [00:21<00:00, 29.49it/s]


Epoch [4/10], Loss: 1.1203


Epoch 5 - Multiclass Task: 100%|██████████| 625/625 [00:20<00:00, 29.89it/s]


Epoch [5/10], Loss: 1.0939


Epoch 6 - Multiclass Task: 100%|██████████| 625/625 [00:21<00:00, 29.49it/s]


Epoch [6/10], Loss: 1.0720


Epoch 7 - Multiclass Task: 100%|██████████| 625/625 [00:21<00:00, 29.23it/s]


Epoch [7/10], Loss: 1.0584


Epoch 8 - Multiclass Task: 100%|██████████| 625/625 [00:21<00:00, 29.56it/s]


Epoch [8/10], Loss: 1.0428


Epoch 9 - Multiclass Task: 100%|██████████| 625/625 [00:20<00:00, 30.23it/s]


Epoch [9/10], Loss: 1.0304


Epoch 10 - Multiclass Task: 100%|██████████| 625/625 [00:21<00:00, 29.48it/s]


Epoch [10/10], Loss: 1.0202
Validation Accuracy (Multiclass Task): 57.40%
Test Accuracy (Multiclass Task): 57.79%
