In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [3]:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
batch_size = 64
device = "cuda:0"

In [5]:
class ResNet50(nn.Module):
    def __init__(self, hidden_units, out_size):
        super(ResNet50, self).__init__()
        self.resnet50 = models.resnet50(pretrained=True)
        self.resnet50.fc_layer = nn.Linear(2048, hidden_units)
        self.resnet50.classifier = nn.Linear(hidden_units, out_size)

    def forward(self, x):
        features = self.resnet50.conv1(x)
        features = self.resnet50.bn1(features)
        features = self.resnet50.relu(features)
        features = self.resnet50.maxpool(features)
        features = self.resnet50.layer1(features)
        features = self.resnet50.layer2(features)
        features = self.resnet50.layer3(features)
        fmaps_b4 = self.resnet50.layer4(features)
        out = F.adaptive_avg_pool2d(fmaps_b4, (1, 1)).view(fmaps_b4.size(0), -1)
        out = self.resnet50.fc_layer(out)
        out = self.resnet50.classifier(out)
        return out


In [22]:
model1 = ResNet50(hidden_units=128, out_size=2).to(device)

In [7]:
class_indices = [0,1]
train_dataset_1 = torch.utils.data.Subset(train_dataset, [i for i in range(len(train_dataset)) if train_dataset.targets[i] in class_indices])
test_dataset_1 = torch.utils.data.Subset(train_dataset, [i for i in range(len(train_dataset)) if train_dataset.targets[i] in class_indices])
train_loader_1 = torch.utils.data.DataLoader(train_dataset_1, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader_1 = torch.utils.data.DataLoader(test_dataset_1, batch_size=batch_size, shuffle=False, num_workers=2)

In [8]:
class_indices = [2,3]
train_dataset_2 = torch.utils.data.Subset(train_dataset, [i for i in range(len(train_dataset)) if train_dataset.targets[i] in class_indices])
test_dataset_2 = torch.utils.data.Subset(train_dataset, [i for i in range(len(train_dataset)) if train_dataset.targets[i] in class_indices])
train_loader_2 = torch.utils.data.DataLoader(train_dataset_2, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader_2 = torch.utils.data.DataLoader(test_dataset_2, batch_size=batch_size, shuffle=False, num_workers=2)

In [23]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model1.parameters(), lr=0.001)
num_epochs = 10

In [24]:
for epoch in range(num_epochs):
    model1.train()
    running_loss = 0.0

    for inputs, labels in train_loader_1:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model1(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader_1)
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {epoch_loss:.4f}")

Epoch [1/10] Loss: 0.1851
Epoch [2/10] Loss: 0.0880
Epoch [3/10] Loss: 0.0532
Epoch [4/10] Loss: 0.0369
Epoch [5/10] Loss: 0.0402
Epoch [6/10] Loss: 0.0311
Epoch [7/10] Loss: 0.1527
Epoch [8/10] Loss: 0.0619
Epoch [9/10] Loss: 0.0357
Epoch [10/10] Loss: 0.0243


In [25]:
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    for images, labels in test_loader_1:
        images, labels = images.to(device), labels.to(device)
        outputs = model1(images)
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

accuracy = 100 * correct_predictions / total_samples
print(f'Accuracy : {accuracy:.2f}%')

Accuracy : 99.86%


In [26]:
model2 = ResNet50(hidden_units=128, out_size=4).to(device)

In [27]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model2.parameters(), lr=0.001)
num_epochs = 15

In [28]:
model1.eval()
kl_criterion = nn.KLDivLoss(reduction='batchmean').to(device)
ce_criterion = nn.CrossEntropyLoss().to(device)
alpha = 0.5

for epoch in range(num_epochs):
    model2.train()

    running_loss = 0.0

    for inputs, labels in train_loader_2:
        inputs, labels = inputs.to(device), labels.to(device)
        labels = labels % 2
        optimizer.zero_grad()
        with torch.no_grad():
            teacher_outputs = model1(inputs)
            zeros_tensor = torch.zeros_like(teacher_outputs)
            teacher_outputs = torch.cat((teacher_outputs, zeros_tensor), dim=1)

        outputs = model2(inputs)

        kl_loss = kl_criterion(torch.log_softmax(outputs, dim=1), torch.softmax(teacher_outputs, dim=1))
        ce_loss = ce_criterion(outputs, labels)
        loss = alpha * ce_loss + (1-alpha) * kl_loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader_2)
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {epoch_loss:.4f}")

Epoch [1/15] Loss: 0.5745
Epoch [2/15] Loss: 0.5254
Epoch [3/15] Loss: 0.5015
Epoch [4/15] Loss: 0.4581
Epoch [5/15] Loss: 0.4352
Epoch [6/15] Loss: 0.4179
Epoch [7/15] Loss: 0.4073
Epoch [8/15] Loss: 0.4064
Epoch [9/15] Loss: 0.3956
Epoch [10/15] Loss: 0.3856
Epoch [11/15] Loss: 0.3821
Epoch [12/15] Loss: 0.3771
Epoch [13/15] Loss: 0.3699
Epoch [14/15] Loss: 0.3707
Epoch [15/15] Loss: 0.3706


In [29]:
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    for images, labels in test_loader_2:
        images, labels = images.cuda(), labels.cuda()
        outputs = model2(images)
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels % 2).sum().item()

accuracy = 100 * correct_predictions / total_samples
print(f'Accuracy : {accuracy:.2f}%')

Accuracy : 95.65%


In [30]:
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    for images, labels in test_loader_1:
        images, labels = images.cuda(), labels.cuda()
        outputs = model2(images)
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

accuracy = 100 * correct_predictions / total_samples
print(f'Accuracy : {accuracy:.2f}%')

Accuracy : 78.44%
