In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from src.style import *
from config import *

# Generate 60 random 2D points, uniform in [-1, 1) per coordinate, with class
# labels 0, 1, or 2 drawn independently of the positions (a pure memorization task).
num_points = 60
seed = 13
torch.manual_seed(seed)
X = 2 * torch.rand((num_points, 2)) - 1  # torch.rand is [0, 1) -> X in [-1, 1)
y = torch.randint(0, 3, (num_points,))  # high is exclusive -> labels {0, 1, 2}
assert X.shape == (num_points, 2) and y.shape == (num_points,)

# Define a simple fully connected network with one hidden ReLU layer
class SimpleNN(nn.Module):
    """MLP: Linear(in_features -> hidden) -> ReLU -> Linear(hidden -> num_classes).

    The output is raw logits (no softmax), which is what nn.CrossEntropyLoss expects.

    Args:
        hidden: width of the hidden layer (default 100).
        in_features: input dimensionality (default 2, i.e. 2D points).
        num_classes: number of output logits (default 3).
    """

    def __init__(self, hidden=100, in_features=2, num_classes=3):
        super().__init__()
        # fc1 is created before fc2 so that, under a fixed seed, the weight
        # initialisation draws happen in a fixed order.
        self.fc1 = nn.Linear(in_features, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Map inputs of shape (..., in_features) to logits of shape (..., num_classes)."""
        return self.fc2(self.relu(self.fc1(x)))



In [None]:

# Initialize the teacher model, loss, and optimizer, then fit it on all points (X, y)
model = SimpleNN(hidden=150)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 4000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

# Measure train accuracy once, on the final weights: the logits computed inside
# the loop predate the last optimizer.step() and would describe a stale model.
with torch.no_grad():
    acc = (torch.argmax(model(X), dim=1) == y).float().mean()
print(acc)

In [None]:
# Number of independent training runs per condition in the transfer experiment
n_repeats: int = 5

In [None]:
# Select 60% of the dataset for transfer learning; the remaining 40% is held out
num_transfer = int(0.6 * num_points)
rand_idx = torch.randperm(num_points)
indices = rand_idx[:num_transfer]
indices_test = rand_idx[num_transfer:]
X_transfer_train = X[indices]
y_transfer_train = y[indices]
X_transfer_test = X[indices_test]
y_transfer_test = y[indices_test]

TRANSFER_EPOCHS = 4000
# Softmax temperature applied to the teacher's logits when building soft labels
SOFTLABEL_TEMPERATURE = 10.0


def _accuracy(net, inputs, labels):
    """Fraction of `inputs` whose argmax prediction under `net` equals `labels` (float)."""
    with torch.no_grad():
        return (net(inputs).argmax(dim=1) == labels).float().mean().item()


def _train_transfer_model(targets, num_epochs=TRANSFER_EPOCHS):
    """Train a freshly initialised SimpleNN(hidden=150) on X_transfer_train.

    Args:
        targets: either integer class indices or per-class probabilities for
            X_transfer_train; nn.CrossEntropyLoss accepts both forms.
        num_epochs: number of full-batch Adam steps.

    Returns:
        (model, train_acc, test_acc): accuracies are measured against the true
        labels (y_transfer_train / y_transfer_test) after the final optimizer step.
    """
    net = SimpleNN(hidden=150)
    opt = optim.Adam(net.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(num_epochs):
        opt.zero_grad()
        loss = loss_fn(net(X_transfer_train), targets)
        loss.backward()
        opt.step()
    train_acc = _accuracy(net, X_transfer_train, y_transfer_train)
    test_acc = _accuracy(net, X_transfer_test, y_transfer_test)
    return net, train_acc, test_acc


# Baseline: independently initialised models trained on the hard labels only
random_accuracies = []
for _ in range(n_repeats):
    random_model_1, acc, test_acc = _train_transfer_model(y_transfer_train)
    print(acc, test_acc)
    random_accuracies.append(test_acc)

# Generate soft labels from the trained teacher (temperature-scaled softmax)
with torch.no_grad():
    softlabels = F.softmax(model(X_transfer_train) / SOFTLABEL_TEMPERATURE, dim=-1)

# Students: trained on the teacher's soft labels, evaluated on the true test labels
student_1_accuracies = []
for _ in range(n_repeats):
    student_model_1, acc, test_acc = _train_transfer_model(softlabels)
    print(acc, test_acc)
    student_1_accuracies.append(test_acc)

In [None]:
# Everything the summary figure needs: trained networks, palette, data split,
# and the per-run test accuracies from the transfer experiment.
models_3_label = {
    # Networks; 'student' and 'random' are the last of the repeated runs
    'teacher': model,
    'student': student_model_1,
    'random': random_model_1,
    # Class id -> dark colour for scatter markers
    'colors': dict(enumerate([
        c_medium_contrast['dark-blue'],
        c_medium_contrast['dark-red'],
        c_medium_contrast['dark-yellow'],
    ])),
    # Light fill colours for the decision regions, in class-id order
    'contour_colors': [
        c_medium_contrast['light-blue'],
        c_medium_contrast['light-red'],
        c_medium_contrast['light-yellow'],
    ],
    # Full dataset and the transfer train/test split
    'X': X,
    'y': y,
    'X_transfer_train': X_transfer_train,
    'y_transfer_train': y_transfer_train,
    'X_transfer_test': X_transfer_test,
    'y_transfer_test': y_transfer_test,
    # Test accuracy of every repeated run, per condition
    'student_accuracies': student_1_accuracies,
    'random_accuracies': random_accuracies,
}

In [None]:
def compare_contours(ax, xx, yy, grid, model1,model2,colors,m1_color,m2_color):
    """Overlay the decision boundaries of two classifiers as line contours on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        xx, yy: meshgrid coordinate arrays; predictions are reshaped to xx.shape.
        grid: model input holding the flattened (xx, yy) points.
        model1, model2: callables returning per-class scores for `grid`; the argmax
            over axis 1 is the predicted class.
        colors: per-class colour mapping; only its length (number of classes) is used.
        m1_color, m2_color: line colours for model1's and model2's boundaries.
    """
    n_classes = len(colors)
    # Levels halfway between integer class ids, so lines fall exactly on class boundaries.
    scales = np.linspace(-0.5, n_classes - 0.5, n_classes + 1)
    for net, line_color in ((model1, m1_color), (model2, m2_color)):
        with torch.no_grad():
            scores = net(grid)
        preds = np.argmax(scores.detach().numpy(), axis=1).reshape(xx.shape)
        # Draw on the Axes we were given, not on whatever plt.gca() happens to be.
        ax.contour(xx, yy, preds, levels=scales, colors=line_color)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    ax.set_aspect('equal', adjustable='box')

def plot_contour_and_points(ax, xx, yy, grid, model, colors, contour_colors, X, y):
    """Shade `model`'s predicted class regions on `ax` and scatter labelled points on top.

    Args:
        ax: matplotlib Axes to draw on.
        xx, yy: meshgrid coordinate arrays; predictions are reshaped to xx.shape.
        grid: model input holding the flattened (xx, yy) points.
        model: callable returning per-class scores for `grid`; argmax over axis 1 is the class.
        colors: mapping class id -> marker colour; its length sets the number of classes.
        contour_colors: fill colours for the class regions, in class-id order.
        X, y: points to scatter (columns 0 and 1) and their integer labels.
    """
    scores = model(grid).detach().numpy()
    region_labels = scores.argmax(axis=1).reshape(xx.shape)
    n_classes = len(colors)
    # Level edges at k - 0.5 put exactly one integer class id inside each filled band.
    class_edges = np.linspace(-0.5, n_classes - 0.5, n_classes + 1)
    ax.contourf(xx, yy, region_labels, levels=class_edges, colors=contour_colors, alpha=0.3)

    marker_colors = [colors[label.item()] for label in y]
    ax.scatter(X[:, 0], X[:, 1], c=marker_colors, edgecolors='white', marker='o', s=100)

    # Illustration only: no ticks or frame, square data aspect.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    ax.set_aspect('equal', adjustable='box')
    
def plot_correct_wrong(ax, X_transfer_test, y_transfer_test, model,name, accuracies):
    """Mark on `ax` which held-out points `model` classifies correctly, and title the panel.

    Correct test points are drawn as hollow green circles and misclassified ones as red
    crosses. The title shows this model's test accuracy, plus the mean and standard error
    of the test accuracies from the repeated runs.

    Args:
        ax: matplotlib Axes to draw on.
        X_transfer_test: test inputs; columns 0 and 1 are used as plot coordinates.
        y_transfer_test: integer class labels of the test inputs.
        model: callable returning per-class scores; argmax over dim 1 is the prediction.
        name: text for the first title line.
        accuracies: test accuracies (fractions in [0, 1]) of repeated runs; Python floats
            and 0-d tensors are both accepted.
    """
    with torch.no_grad():
        correct_samples = (model(X_transfer_test).argmax(dim=1) == y_transfer_test)
    ax.scatter(X_transfer_test[correct_samples, 0], X_transfer_test[correct_samples, 1],
               edgecolors='green', marker='o', s=100, facecolors='none')
    ax.scatter(X_transfer_test[~correct_samples, 0], X_transfer_test[~correct_samples, 1],
               edgecolors='red', marker='x', s=50, facecolors='red')

    # Round rather than truncate percentages: int(0.29 * 100) == 28.
    test_pct = correct_samples.float().mean().item() * 100
    runs = np.asarray([float(a) for a in accuracies])
    mean_pct = runs.mean() * 100
    # Standard error of the mean; np.std defaults to ddof=0 (population std).
    sem_pct = runs.std() * 100 / np.sqrt(len(runs))
    # The mathtext part is a raw f-string so '\pm' and '\%' are not read as
    # (invalid) Python escape sequences.
    ax.set_title(f'{name}\n${{{test_pct:.0f}}}$% test acc. '
                 + rf'[avg. ${mean_pct:.0f}\pm{sem_pct:.1f}\%$]')


In [None]:
def plot_summary(axes,models_3_label, resolution=100):
    """Draw the four-panel transfer-learning summary onto `axes`.

    Panels, in order:
      0: teacher regions with all points, titled with its measured train accuracy;
      1: independently trained model on the transfer split, test points marked;
      2: student on the transfer split, test points marked;
      3: teacher vs. student decision boundaries overlaid.

    Args:
        axes: sequence of at least four matplotlib Axes.
        models_3_label: dict with keys 'teacher', 'student', 'random', 'colors',
            'contour_colors', 'X', 'y', 'X_transfer_train', 'y_transfer_train',
            'X_transfer_test', 'y_transfer_test', 'student_accuracies',
            'random_accuracies'.
        resolution: grid points per axis used to evaluate the classifiers on [-1, 1]^2.
    """
    data = models_3_label
    xx, yy = np.meshgrid(np.linspace(-1, 1, resolution), np.linspace(-1, 1, resolution))
    grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)

    # Panel A: teacher on the full dataset. Report the measured train accuracy
    # rather than asserting 100% in the title.
    ax = axes[0]
    plot_contour_and_points(ax, xx, yy, grid,
                            data['teacher'], data['colors'],
                            data['contour_colors'],
                            data['X'],
                            data['y'])
    with torch.no_grad():
        teacher_acc = (data['teacher'](data['X']).argmax(dim=1) == data['y']).float().mean().item()
    ax.set_title(f'teacher\n{teacher_acc * 100:.0f}% train acc.')
    ax.set_aspect('equal', adjustable='box')

    # Panels B and C: independent model and student, both on the transfer split.
    transfer_panels = (
        (axes[1], 'random', 'random_accuracies', 'independent'),
        (axes[2], 'student', 'student_accuracies', 'student'),
    )
    for ax, model_key, acc_key, label in transfer_panels:
        plot_contour_and_points(ax, xx, yy, grid,
                                data[model_key], data['colors'],
                                data['contour_colors'],
                                data['X_transfer_train'],
                                data['y_transfer_train'])
        plot_correct_wrong(ax, data['X_transfer_test'], data['y_transfer_test'],
                           data[model_key], name=label, accuracies=data[acc_key])

    # Panel D: overlay of the teacher's (panel A) and student's (panel C) boundaries.
    ax = axes[3]
    ax.set_title('decision boundaries\nteacher A vs. student C')
    compare_contours(ax, xx, yy, grid,
                     data['teacher'], data['student'],
                     data['colors'],
                     'black',
                     'tab:orange')
    
    

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(12 * 1.5, 3 * 1.5))
plot_summary(axes, models_3_label)

# Bold "(A)".."(D)" panel letters, placed in data coordinates above each panel
panel_label_style = dict(fontsize=14, color='black', ha='left', va='top', fontweight='bold')
for panel_ax, letter in zip(axes, "ABCD"):
    panel_ax.text(-1, 1.3, f'({letter})', **panel_label_style)

plt.savefig(FIGURE_DIR / 'transfer_learning_3_labels.png', bbox_inches='tight', dpi=300)
