In [1]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb
from utils import *
from networks import *

In [2]:
# Preprocessing shared by every split: PIL image -> tensor, then normalize
# the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])


def _seeded_generator():
    """Return a fresh torch.Generator seeded with 872 (one per consumer)."""
    return torch.Generator().manual_seed(872)


# MNIST train / test splits (downloaded into ./data when missing).
full_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Hold out 5% of the official training split for validation (seeded split).
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [0.95, 0.05],
    generator=_seeded_generator(),
)

# Batched loaders; train and validation each get their own seeded shuffle generator.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=_seeded_generator())
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True, generator=_seeded_generator())
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [3]:
# Train the teacher network with plain cross-entropy, validating after every
# epoch and evaluating once on the test set at the end. Metrics go to wandb.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model = TeacherNetwork().to(device)
optimizer = optim.SGD(teacher_model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-5)
num_epochs = 20
# Build the loss module once instead of once per batch.
criterion = nn.CrossEntropyLoss()

wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Teacher Model",
    config={}
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Sums are weighted by batch size so the epoch average is exact even when
    # the last batch is smaller.
    train_loss, train_accuracy, size = 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = teacher_model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a plain attribute set on the model; how
    # TeacherNetwork consumes it lives in networks.py (not visible here). If the
    # network uses nn.Dropout / nn.BatchNorm modules, `teacher_model.eval()` and
    # `.train()` would be needed instead -- confirm.
    teacher_model.is_training = False

    val_loss, val_accuracy, size = 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = teacher_model(data)
            # .item() keeps the accumulator a Python float (as in the training
            # loop) instead of a device tensor that would be handed to wandb.
            val_loss += criterion(pred, targets).item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy = val_loss / size, val_accuracy / size

    print(f'Average Validation Loss: {val_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_ce_loss": train_loss,
               "val_ce_loss": val_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

    teacher_model.is_training = True

# Final evaluation on the held-out test set. The flag stays False afterwards,
# so later cells that query the teacher see it in this state.
teacher_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = teacher_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; the original message mislabeled them as validation.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: martingratzer (martingratzer-organization). Use `wandb login --relogin` to force relogin


Starting training...
---Epoch 1---
Average Train Loss: 0.5390 	 	 Train Accuracy: 82.60%
Average Validation Loss: 0.2158 	 Validation Accuracy: 93.47%
---Epoch 2---
Average Train Loss: 0.2882 	 	 Train Accuracy: 90.90%
Average Validation Loss: 0.1445 	 Validation Accuracy: 95.40%
---Epoch 3---
Average Train Loss: 0.2284 	 	 Train Accuracy: 92.86%
Average Validation Loss: 0.1337 	 Validation Accuracy: 95.60%
---Epoch 4---
Average Train Loss: 0.2021 	 	 Train Accuracy: 93.74%
Average Validation Loss: 0.1054 	 Validation Accuracy: 97.13%
---Epoch 5---
Average Train Loss: 0.1805 	 	 Train Accuracy: 94.42%
Average Validation Loss: 0.1165 	 Validation Accuracy: 96.37%
---Epoch 6---
Average Train Loss: 0.1662 	 	 Train Accuracy: 94.91%
Average Validation Loss: 0.0982 	 Validation Accuracy: 97.23%
---Epoch 7---
Average Train Loss: 0.1550 	 	 Train Accuracy: 95.14%
Average Validation Loss: 0.0848 	 Validation Accuracy: 97.60%
---Epoch 8---
Average Train Loss: 0.1485 	 	 Train Accuracy: 95.33%
A

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train_accuracy,▁▅▆▆▇▇▇▇▇▇▇█████████
train_ce_loss,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▄▆▅▆▇▇▇▇▇▇▇███████
val_ce_loss,█▅▄▃▄▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁

0,1
epoch,19.0
train_accuracy,97.17544
train_ce_loss,0.08879
val_accuracy,98.03333
val_ce_loss,0.065


In [26]:
# Baseline: train the student architecture on hard labels only (no teacher),
# validating after every epoch and evaluating once on the test set at the end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vanilla_model = StudentNetwork().to(device)
optimizer = optim.SGD(vanilla_model.parameters(), lr=0.02, momentum=0.7, weight_decay=1e-4)
num_epochs = 10
# Build the loss module once instead of once per batch.
criterion = nn.CrossEntropyLoss()

wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Vanilla Model",
    config={}
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Sums are weighted by batch size so the epoch average is exact even when
    # the last batch is smaller.
    train_loss, train_accuracy, size = 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = vanilla_model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a plain attribute set on the model; how
    # StudentNetwork consumes it lives in networks.py (not visible here). If the
    # network uses nn.Dropout / nn.BatchNorm modules, `.eval()` / `.train()`
    # would be needed instead -- confirm.
    vanilla_model.is_training = False

    val_loss, val_accuracy, size = 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = vanilla_model(data)
            # .item() keeps the accumulator a Python float (as in the training
            # loop) instead of a device tensor that would be handed to wandb.
            val_loss += criterion(pred, targets).item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy = val_loss / size, val_accuracy / size

    print(f'Average Validation Loss: {val_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_ce_loss": train_loss,
               "val_ce_loss": val_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

    vanilla_model.is_training = True

# Final evaluation on the held-out test set.
vanilla_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = vanilla_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; the original message mislabeled them as validation.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()


KeyboardInterrupt



In [6]:
# Distill the teacher into a student: the objective blends hard-label
# cross-entropy with a temperature-scaled RenyiDivergence term.
# Requires `teacher_model` from the teacher-training cell to be in scope.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = StudentNetwork().to(device)
optimizer = optim.SGD(student_model.parameters(), lr=0.045, momentum=0.65, weight_decay=1e-4)
num_epochs = 10
temperature = 2.2
beta = 0.55
alpha = 1
# Build the loss module once instead of once per batch.
criterion = nn.CrossEntropyLoss()


def renyi_distillation_loss(student_logits, teacher_logits, ce_loss):
    """Return (1-beta) * CE + beta * RenyiDivergence(student/T, teacher/T) * T^2 / alpha.

    Reads `alpha`, `beta` and `temperature` from the notebook namespace at call
    time. `ce_loss` is the already-computed cross-entropy of `student_logits`.
    """
    divergence = RenyiDivergence(alpha=alpha)(student_logits/temperature, teacher_logits/temperature)
    return (1-beta) * ce_loss + beta * divergence * (temperature**2) / alpha


wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Student Model",
    config={
        "beta": beta,
        "temperature": temperature,
        "alpha": alpha
    }
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Sums are weighted by batch size so the epoch average is exact even when
    # the last batch is smaller.
    train_loss, train_accuracy, size, train_CE_loss = 0, 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = student_model(data)
        # The teacher only provides targets; no gradients are needed for it.
        with torch.no_grad():
            teacher_outputs = teacher_model(data)
        # CE is computed once and reused for both the objective and the report.
        CE_loss = criterion(outputs, targets)
        loss = renyi_distillation_loss(outputs, teacher_outputs, CE_loss)
        loss.backward()

        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_CE_loss = train_CE_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a plain attribute set on the model; how
    # StudentNetwork consumes it lives in networks.py (not visible here). If the
    # network uses nn.Dropout / nn.BatchNorm modules, `.eval()` / `.train()`
    # would be needed instead -- confirm.
    student_model.is_training = False

    val_loss, val_accuracy, size, val_CE_loss = 0, 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = student_model(data)
            teacher_pred = teacher_model(data)
            val_ce = criterion(pred, targets)
            # .item() keeps the accumulators Python floats (as in the training
            # loop) instead of device tensors that would be handed to wandb.
            val_loss += renyi_distillation_loss(pred, teacher_pred, val_ce).item() * targets.shape[0]
            val_CE_loss += val_ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size

    print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_loss": train_loss,
               "train_ce_loss": train_CE_loss,
               "val_loss": val_loss,
               "val_ce_loss": val_CE_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

    student_model.is_training = True

# Final evaluation on the held-out test set (cross-entropy only).
student_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = student_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; the original message mislabeled them as validation.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

0,1
epoch,▁
train_accuracy,▁
train_ce_loss,▁
val_accuracy,▁
val_ce_loss,▁

0,1
epoch,0.0
train_accuracy,82.58947
train_ce_loss,0.53911
val_accuracy,93.7
val_ce_loss,0.20564


Starting training...
---Epoch 1---
Average Train Loss: 0.5658 	 	 Average Train CE Loss: 0.3549 	 	 Train Accuracy: 89.16%
Average Validation Loss: 0.3301 	 Average Validation CE Loss: 0.2183 	 Validation Accuracy: 94.23%
---Epoch 2---


KeyboardInterrupt: 

In [None]:
def train(config=None,lr=None,momentum=None,weight_decay=None,temperature=None,beta=None,alpha=None,destilation=False,teacher_model=None,epochs=20):
    """Train a fresh StudentNetwork inside its own wandb run (usable as a sweep target).

    Hyperparameters are read from `wandb.config` (populated from `config` or by
    a sweep agent). Any of `lr`, `momentum`, `weight_decay`, `temperature`,
    `beta`, `alpha` that is not None is written into that config before use.

    Args:
        config: Optional initial config forwarded to `wandb.init`.
        lr, momentum, weight_decay: Optional overrides for the SGD optimizer settings.
        temperature, beta, alpha: Optional overrides for the distillation settings
            that `train_epoch` reads from the config when `destilation` is True.
        destilation: When True, `train_epoch` adds the teacher-based loss term.
        teacher_model: Teacher network; required when `destilation` is True.
        epochs: Number of epochs to run (defaults to the previously hard-coded 20).

    Raises:
        ValueError: If `destilation` is True but no `teacher_model` was given.
    """
    # Fail before opening a wandb run rather than on the first training batch.
    if destilation and teacher_model is None:
        raise ValueError("teacher_model must be provided when destilation=True")

    with wandb.init(config=config):
        config = wandb.config
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        network = StudentNetwork().to(device)

        # The test loader is not used during training / sweeps.
        train_loader, val_loader, _ = build_dataset()

        # Explicit keyword arguments take precedence over the run's config.
        overrides = {
            "lr": lr,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "temperature": temperature,
            "beta": beta,
            "alpha": alpha,
        }
        for key, value in overrides.items():
            if value is not None:
                setattr(config, key, value)

        optimizer = optim.SGD(network.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)

        for epoch in range(epochs):
            train_epoch(epoch, network, train_loader, val_loader, optimizer, device, destilation, teacher_model, config)

In [None]:
def build_dataset(batch_size=64, test_batch_size=1000, seed=872, root='./data'):
    """Build the MNIST train / validation / test DataLoaders.

    The official training split is divided 95% / 5% into train and validation
    with a seeded `random_split`. The train and validation loaders shuffle with
    their own generators seeded by `seed`; the test loader is not shuffled.
    With the default arguments this reproduces the previously hard-coded setup.

    Args:
        batch_size: Batch size for the train and validation loaders.
        test_batch_size: Batch size for the test loader.
        seed: Seed for the split and for the shuffling generators.
        root: Directory where MNIST is stored (downloaded there when missing).

    Returns:
        Tuple `(train_loader, val_loader, test_loader)`.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])
    full_train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset, [0.95, 0.05], generator=torch.Generator().manual_seed(seed))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator().manual_seed(seed))
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, generator=torch.Generator().manual_seed(seed))
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

    return train_loader, val_loader, test_loader
    
def train_epoch(epoch, network, train_loader, val_loader, optimizer, device, destilation, teacher_model, config):
    """Run one training pass and one validation pass, print and log the metrics.

    Args:
        epoch: Epoch index, logged to wandb under "epoch".
        network: Model being trained.
        train_loader: Loader yielding `(data, targets)` training batches.
        val_loader: Loader yielding `(data, targets)` validation batches.
        optimizer: Optimizer stepping `network`'s parameters.
        device: Device the batches are moved to.
        destilation: When True, the objective is
            `(1-beta) * CE + beta * RenyiDivergence(student/T, teacher/T) * T^2 / alpha`
            with `beta`, `temperature`, `alpha` read from `config`; otherwise
            plain cross-entropy.
        teacher_model: Teacher network, only called when `destilation` is True.
        config: Object exposing `beta`, `temperature`, `alpha` as attributes.

    Returns:
        None. Metrics are printed and sent through `wandb.log`.
    """
    # Build the loss module once instead of once per batch.
    criterion = nn.CrossEntropyLoss()

    def distillation_objective(student_logits, teacher_logits, ce_loss):
        # Blend hard-label CE with the temperature-scaled RenyiDivergence term.
        divergence = RenyiDivergence(alpha=config.alpha)(
            student_logits/config.temperature, teacher_logits/config.temperature)
        return (1-config.beta) * ce_loss + config.beta * divergence * (config.temperature**2) / config.alpha

    # Sums are weighted by batch size so the epoch average is exact even when
    # the last batch is smaller.
    train_loss, train_accuracy, size, train_CE_loss = 0, 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = network(data)
        # CE is computed once and reused for both the objective and the report.
        CE_loss = criterion(outputs, targets)
        if destilation:
            # The teacher only provides targets; no gradients are needed for it.
            with torch.no_grad():
                teacher_outputs = teacher_model(data)
            loss = distillation_objective(outputs, teacher_outputs, CE_loss)
        else:
            loss = CE_loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss, train_CE_loss, train_accuracy = train_loss/size, train_CE_loss/size, train_accuracy/size
    if destilation:
        print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')
    else:
        print(f'Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a plain attribute set on the model; how the
    # network consumes it lives in networks.py (not visible here). If it uses
    # nn.Dropout / nn.BatchNorm modules, `.eval()` / `.train()` would be needed
    # instead -- confirm.
    network.is_training = False

    val_loss, val_accuracy, size, val_CE_loss = 0, 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = network(data)
            val_ce = criterion(pred, targets)
            if destilation:
                # Single teacher forward pass (it used to be run twice per batch).
                teacher_pred = teacher_model(data)
                val_loss += distillation_objective(pred, teacher_pred, val_ce).item() * targets.shape[0]
            # .item() keeps the accumulators Python floats instead of device
            # tensors that would be handed to wandb.
            val_CE_loss += val_ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size
    if destilation:
        print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')
    else:
        print(f'Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    metrics = {"train_ce_loss": train_CE_loss,
               "val_ce_loss": val_CE_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    if destilation:
        # The combined objective is only meaningful (and only logged) when distilling.
        metrics["train_loss"] = train_loss
        metrics["val_loss"] = val_loss
    wandb.log(metrics)

    network.is_training = True

In [None]:
def _uniform(lower, upper):
    """Build a sweep parameter spec that samples uniformly between the bounds."""
    return {'distribution': 'uniform', 'min': lower, 'max': upper}


# Bayesian search over the distillation and optimizer hyperparameters,
# maximizing the logged `val_accuracy` metric.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'beta': _uniform(0.5, 1),
        'temperature': _uniform(2, 6),
        'lr': _uniform(0.01, 0.05),
        'momentum': _uniform(0.6, 0.9),
        'weight_decay': _uniform(1e-4, 1e-3),
    },
}

In [None]:
# Register the sweep described by `sweep_config` in the given wandb project;
# the returned id is what the agent cell uses to pull trials.
sweep_id = wandb.sweep(sweep_config, project="Renyi_Divergence_Sweep_Student")

In [None]:
def _run_distillation_trial():
    """One sweep trial: distill from the notebook's `teacher_model` with alpha fixed to 1."""
    return train(alpha=1, destilation=True, teacher_model=teacher_model)


# Run 30 trials of the registered sweep in this process.
wandb.agent(sweep_id, function=_run_distillation_trial, count=30)

In [10]:
# Distill the teacher into a student using the DKD loss from utils, blended
# with hard-label cross-entropy.
# Requires `teacher_model` from the teacher-training cell to be in scope.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = StudentNetwork().to(device)
optimizer = optim.SGD(student_model.parameters(), lr=0.045, momentum=0.65, weight_decay=1e-4)
num_epochs = 10
temperature = 2.2
beta = 0.5
alpha = 0.5
zeta = 0.5
# Build the loss module once instead of once per batch.
criterion = nn.CrossEntropyLoss()


def dkd_distillation_loss(student_logits, teacher_logits, labels, ce_loss):
    """Return (1-zeta) * CE + zeta * DKD(student/T, teacher/T, labels) * T^2.

    Reads `alpha`, `beta`, `zeta` and `temperature` from the notebook namespace
    at call time. `ce_loss` is the already-computed cross-entropy of
    `student_logits`.
    """
    dkd_term = DKD(alpha=alpha,beta=beta)(student_logits/temperature, teacher_logits/temperature, labels)
    return (1-zeta) * ce_loss + zeta * dkd_term * (temperature**2)


wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Student Model using DKD",
    config={
        "beta": beta,
        "temperature": temperature,
        "alpha": alpha,
        "zeta": zeta
    }
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Sums are weighted by batch size so the epoch average is exact even when
    # the last batch is smaller.
    train_loss, train_accuracy, size, train_CE_loss = 0, 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = student_model(data)
        # The teacher only provides targets; no gradients are needed for it.
        with torch.no_grad():
            teacher_outputs = teacher_model(data)
        # CE is computed once and reused for both the objective and the report.
        CE_loss = criterion(outputs, targets)
        loss = dkd_distillation_loss(outputs, teacher_outputs, targets, CE_loss)
        loss.backward()

        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_CE_loss = train_CE_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a plain attribute set on the model; how
    # StudentNetwork consumes it lives in networks.py (not visible here). If the
    # network uses nn.Dropout / nn.BatchNorm modules, `.eval()` / `.train()`
    # would be needed instead -- confirm.
    student_model.is_training = False

    val_loss, val_accuracy, size, val_CE_loss = 0, 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = student_model(data)
            teacher_pred = teacher_model(data)
            val_ce = criterion(pred, targets)
            # .item() keeps the accumulators Python floats (as in the training
            # loop) instead of device tensors that would be handed to wandb.
            val_loss += dkd_distillation_loss(pred, teacher_pred, targets, val_ce).item() * targets.shape[0]
            val_CE_loss += val_ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size

    print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_loss": train_loss,
               "train_ce_loss": train_CE_loss,
               "val_loss": val_loss,
               "val_ce_loss": val_CE_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

    student_model.is_training = True

# Final evaluation on the held-out test set (cross-entropy only).
student_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = student_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; the original message mislabeled them as validation.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

KeyboardInterrupt: 

In [8]:
# Distill the teacher into a student using the GDKD loss from utils, blended
# with hard-label cross-entropy.
# Requires `teacher_model` from the teacher-training cell to be in scope.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = StudentNetwork().to(device)
optimizer = optim.SGD(student_model.parameters(), lr=0.045, momentum=0.65, weight_decay=1e-4)
num_epochs = 10
temperature = 2.2
alpha=0.5
zeta = 0.5
delta=0.5
gamma=0.5
# Build the loss module once instead of once per batch.
criterion = nn.CrossEntropyLoss()


def gdkd_distillation_loss(student_logits, teacher_logits, labels, ce_loss):
    """Return (1-zeta) * CE + zeta * GDKD(student/T, teacher/T, labels) * T^2 / alpha.

    Reads `alpha`, `delta`, `gamma`, `zeta` and `temperature` from the notebook
    namespace at call time. `ce_loss` is the already-computed cross-entropy of
    `student_logits`.
    """
    gdkd_term = GDKD(alpha=alpha,delta=delta,gamma=gamma)(student_logits/temperature, teacher_logits/temperature, labels)
    return (1-zeta) * ce_loss + zeta * gdkd_term * (temperature**2) / alpha


wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Student Model using GDKD",
    config={
        "temperature": temperature,
        "alpha": alpha,
        "zeta": zeta,
        "delta": delta,
        "gamma": gamma
    }
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Sums are weighted by batch size so the epoch average is exact even when
    # the last batch is smaller.
    train_loss, train_accuracy, size, train_CE_loss = 0, 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = student_model(data)
        # The teacher only provides targets; no gradients are needed for it.
        with torch.no_grad():
            teacher_outputs = teacher_model(data)
        # CE is computed once and reused for both the objective and the report.
        CE_loss = criterion(outputs, targets)
        loss = gdkd_distillation_loss(outputs, teacher_outputs, targets, CE_loss)
        loss.backward()

        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_CE_loss = train_CE_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a plain attribute set on the model; how
    # StudentNetwork consumes it lives in networks.py (not visible here). If the
    # network uses nn.Dropout / nn.BatchNorm modules, `.eval()` / `.train()`
    # would be needed instead -- confirm.
    student_model.is_training = False

    val_loss, val_accuracy, size, val_CE_loss = 0, 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = student_model(data)
            teacher_pred = teacher_model(data)
            val_ce = criterion(pred, targets)
            # .item() keeps the accumulators Python floats (as in the training
            # loop) instead of device tensors that would be handed to wandb.
            val_loss += gdkd_distillation_loss(pred, teacher_pred, targets, val_ce).item() * targets.shape[0]
            val_CE_loss += val_ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size

    print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_loss": train_loss,
               "train_ce_loss": train_CE_loss,
               "val_loss": val_loss,
               "val_ce_loss": val_CE_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

    student_model.is_training = True

# Final evaluation on the held-out test set (cross-entropy only).
student_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = student_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; the original message mislabeled them as validation.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

Starting training...
---Epoch 1---
Average Train Loss: 0.8538 	 	 Average Train CE Loss: 0.3278 	 	 Train Accuracy: 90.01%
Average Validation Loss: 0.5063 	 Average Validation CE Loss: 0.1341 	 Validation Accuracy: 96.13%
---Epoch 2---
Average Train Loss: 0.3737 	 	 Average Train CE Loss: 0.1525 	 	 Train Accuracy: 95.48%
Average Validation Loss: 0.3292 	 Average Validation CE Loss: 0.1022 	 Validation Accuracy: 96.77%
---Epoch 3---
Average Train Loss: 0.3078 	 	 Average Train CE Loss: 0.1223 	 	 Train Accuracy: 96.28%
Average Validation Loss: 0.2746 	 Average Validation CE Loss: 0.0932 	 Validation Accuracy: 96.97%
---Epoch 4---
Average Train Loss: 0.2790 	 	 Average Train CE Loss: 0.1096 	 	 Train Accuracy: 96.65%
Average Validation Loss: 0.2928 	 Average Validation CE Loss: 0.1075 	 Validation Accuracy: 96.50%
---Epoch 5---
Average Train Loss: 0.2610 	 	 Average Train CE Loss: 0.1009 	 	 Train Accuracy: 96.85%
Average Validation Loss: 0.2133 	 Average Validation CE Loss: 0.0877 	 Va

0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_accuracy,▁▆▇▇▇█████
train_ce_loss,█▃▂▂▁▁▁▁▁▁
train_loss,█▃▂▂▁▁▁▁▁▁
val_accuracy,▁▄▅▃▆▇████
val_ce_loss,█▄▃▅▂▂▁▂▁▁
val_loss,█▄▃▃▂▂▂▁▁▁

0,1
epoch,9.0
train_accuracy,97.39649
train_ce_loss,0.08372
train_loss,0.22438
val_accuracy,97.73333
val_ce_loss,0.07737
val_loss,0.18496


In [24]:
# Distill the teacher into a student using the NKD loss from utils, blended
# with hard-label cross-entropy.
# Requires `teacher_model` from the teacher-training cell to be in scope.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = StudentNetwork().to(device)
optimizer = optim.SGD(student_model.parameters(), lr=0.045, momentum=0.65, weight_decay=1e-4)
num_epochs = 10
temperature = 2.2
alpha = 1.5
zeta = 0.5
# Build the loss module once instead of once per batch.
criterion = nn.CrossEntropyLoss()


def nkd_distillation_loss(student_logits, teacher_logits, labels, ce_loss):
    """Return (1-zeta) * CE + zeta * NKD(student/T, teacher/T, labels) * T^2 / alpha.

    Reads `alpha`, `zeta` and `temperature` from the notebook namespace at call
    time. `ce_loss` is the already-computed cross-entropy of `student_logits`.
    """
    nkd_term = NKD(alpha=alpha)(student_logits/temperature, teacher_logits/temperature, labels)
    return (1-zeta) * ce_loss + zeta * nkd_term * (temperature**2) / alpha


wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Student Model using NKD",
    config={
        "temperature": temperature,
        "alpha": alpha,
        "zeta": zeta
    }
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Sums are weighted by batch size so the epoch average is exact even when
    # the last batch is smaller.
    train_loss, train_accuracy, size, train_CE_loss = 0, 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = student_model(data)
        # The teacher only provides targets; no gradients are needed for it.
        with torch.no_grad():
            teacher_outputs = teacher_model(data)
        # CE is computed once and reused for both the objective and the report.
        CE_loss = criterion(outputs, targets)
        loss = nkd_distillation_loss(outputs, teacher_outputs, targets, CE_loss)
        loss.backward()

        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_CE_loss = train_CE_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a plain attribute set on the model; how
    # StudentNetwork consumes it lives in networks.py (not visible here). If the
    # network uses nn.Dropout / nn.BatchNorm modules, `.eval()` / `.train()`
    # would be needed instead -- confirm.
    student_model.is_training = False

    val_loss, val_accuracy, size, val_CE_loss = 0, 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = student_model(data)
            teacher_pred = teacher_model(data)
            val_ce = criterion(pred, targets)
            # .item() keeps the accumulators Python floats (as in the training
            # loop) instead of device tensors that would be handed to wandb.
            val_loss += nkd_distillation_loss(pred, teacher_pred, targets, val_ce).item() * targets.shape[0]
            val_CE_loss += val_ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size

    print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_loss": train_loss,
               "train_ce_loss": train_CE_loss,
               "val_loss": val_loss,
               "val_ce_loss": val_CE_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

    student_model.is_training = True

# Final evaluation on the held-out test set (cross-entropy only).
student_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = student_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; the original message mislabeled them as validation.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

Starting training...
---Epoch 1---
Average Train Loss: 5.1549 	 	 Average Train CE Loss: 0.3809 	 	 Train Accuracy: 88.74%
Average Validation Loss: 3.8960 	 Average Validation CE Loss: 0.1651 	 Validation Accuracy: 95.30%
---Epoch 2---
Average Train Loss: 4.2405 	 	 Average Train CE Loss: 0.1963 	 	 Train Accuracy: 94.62%
Average Validation Loss: 3.7661 	 Average Validation CE Loss: 0.1347 	 Validation Accuracy: 95.90%
---Epoch 3---
Average Train Loss: 4.0978 	 	 Average Train CE Loss: 0.1678 	 	 Train Accuracy: 95.35%
Average Validation Loss: 3.6441 	 Average Validation CE Loss: 0.1230 	 Validation Accuracy: 96.67%
---Epoch 4---
Average Train Loss: 4.0184 	 	 Average Train CE Loss: 0.1524 	 	 Train Accuracy: 95.75%
Average Validation Loss: 3.5770 	 Average Validation CE Loss: 0.1049 	 Validation Accuracy: 97.13%
---Epoch 5---
Average Train Loss: 3.9673 	 	 Average Train CE Loss: 0.1411 	 	 Train Accuracy: 96.04%
Average Validation Loss: 3.5788 	 Average Validation CE Loss: 0.1125 	 Va

wandb: ERROR Control-C detected -- Run data was not synced


KeyboardInterrupt: 

In [50]:
### Filtering out all 3s
# Baseline student: trained with plain cross-entropy on batches from which every
# sample labelled 3 has been removed, then scored on the unfiltered validation
# and test loaders, and finally on the test-set 3s alone.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vanilla_model = StudentNetwork().to(device)
optimizer = optim.SGD(vanilla_model.parameters(), lr=0.02, momentum=0.7, weight_decay=1e-4)
num_epochs = 10
ce_loss_fn = nn.CrossEntropyLoss()  # built once; default reduction is the batch mean

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Running sums over the epoch; divided by `size` afterwards.
    train_loss, train_accuracy, size = 0.0, 0, 0
    for data, targets in train_loader:
        # Remove the 3s before the batch reaches the model.
        mask = targets != 3
        data, targets = data[mask], targets[mask]
        if targets.numel() == 0:
            # A batch made up entirely of 3s leaves nothing to train on, and a
            # mean-reduced loss over zero samples is not well defined.
            continue
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = vanilla_model(data)
        loss = ce_loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        # The loss is a batch mean, so weight it by the batch size.
        train_loss += loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a custom attribute rather than nn.Module's
    # built-in train()/eval() switch -- confirm StudentNetwork actually reads it.
    vanilla_model.is_training = False

    # Validation uses the unfiltered loader, so 3s are included here.
    val_loss, val_accuracy, size = 0.0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = vanilla_model(data)
            # `.item()` keeps the running sum a Python float instead of a device tensor.
            val_loss += ce_loss_fn(pred, targets).item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy = val_loss / size, val_accuracy / size

    print(f'Average Validation Loss: {val_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    vanilla_model.is_training = True

# Evaluation on the complete test set (all ten digits).
vanilla_model.is_training = False
test_loss, test_accuracy, size = 0.0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = vanilla_model(data)
        test_loss += ce_loss_fn(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These numbers come from test_loader, so they are reported as test metrics.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

# Accuracy restricted to the class that was withheld during training.
# From here on `test_accuracy` holds a raw count of correct 3s, not a fraction.
test_accuracy, size = 0, 0
with torch.no_grad():
    for data, targets in test_loader:
        mask = targets == 3
        data, targets = data[mask], targets[mask]
        data, targets = data.to(device), targets.to(device)
        pred = vanilla_model(data)
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

print(f'Test Accuracy on 3s only: {100*test_accuracy/size:.2f}%')

Starting training...
---Epoch 1---
Average Train Loss: 0.3957 	 	 Train Accuracy: 87.95%
Average Validation Loss: 1.0994 	 Validation Accuracy: 83.53%
---Epoch 2---
Average Train Loss: 0.2107 	 	 Train Accuracy: 93.65%
Average Validation Loss: 1.0710 	 Validation Accuracy: 85.63%
---Epoch 3---
Average Train Loss: 0.1634 	 	 Train Accuracy: 95.08%
Average Validation Loss: 1.1186 	 Validation Accuracy: 86.63%
---Epoch 4---
Average Train Loss: 0.1373 	 	 Train Accuracy: 95.90%
Average Validation Loss: 1.1318 	 Validation Accuracy: 86.70%
---Epoch 5---
Average Train Loss: 0.1219 	 	 Train Accuracy: 96.30%
Average Validation Loss: 1.2347 	 Validation Accuracy: 87.07%
---Epoch 6---
Average Train Loss: 0.1067 	 	 Train Accuracy: 96.74%
Average Validation Loss: 1.2117 	 Validation Accuracy: 86.83%
---Epoch 7---
Average Train Loss: 0.1010 	 	 Train Accuracy: 96.90%
Average Validation Loss: 1.2052 	 Validation Accuracy: 86.97%
---Epoch 8---
Average Train Loss: 0.0912 	 	 Train Accuracy: 97.12%
A

In [51]:
# Distilled student: trained on batches with every sample labelled 3 removed,
# using a weighted sum of hard-label cross-entropy and a temperature-scaled
# RenyiDivergence term computed against the teacher's outputs.
# `teacher_model` is not created in this cell; it must already exist from an
# earlier cell. RenyiDivergence is not defined here either (presumably it comes
# from the star-imported `utils`), so its exact semantics are not visible here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = StudentNetwork().to(device)
optimizer = optim.SGD(student_model.parameters(), lr=0.045, momentum=0.65, weight_decay=1e-4)
num_epochs = 10
temperature = 2.2  # both models' outputs are divided by this before the divergence
beta = 0.55        # weight of the divergence term; (1 - beta) weights cross-entropy
alpha = 1          # passed to RenyiDivergence and also used to rescale its value
ce_loss_fn = nn.CrossEntropyLoss()  # built once; default reduction is the batch mean

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # Running sums over the epoch; divided by `size` afterwards.
    train_loss, train_accuracy, size, train_CE_loss = 0.0, 0, 0, 0.0
    for data, targets in train_loader:
        # Remove the 3s before the batch reaches either model.
        mask = targets != 3
        data, targets = data[mask], targets[mask]
        if targets.numel() == 0:
            # A batch made up entirely of 3s leaves nothing to train on, and a
            # mean-reduced loss over zero samples is not well defined.
            continue
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = student_model(data)
        with torch.no_grad():
            # The teacher only supplies targets; no gradient flows into it.
            teacher_outputs = teacher_model(data)
        # Cross-entropy is computed once and reused for both the combined
        # objective and the separately reported CE metric.
        CE_loss = ce_loss_fn(outputs, targets)
        loss = (1-beta) * CE_loss + beta * RenyiDivergence(alpha=alpha)(outputs/temperature, teacher_outputs/temperature) * (temperature**2) / alpha
        loss.backward()

        optimizer.step()
        # Losses are batch means, so weight them by the batch size.
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_CE_loss = train_CE_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # NOTE(review): `is_training` is a custom attribute rather than nn.Module's
    # built-in train()/eval() switch -- confirm StudentNetwork actually reads it.
    student_model.is_training = False

    # Validation uses the unfiltered loader, so 3s are included here.
    val_loss, val_accuracy, size, val_CE_loss = 0.0, 0, 0, 0.0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = student_model(data)
            teacher_pred = teacher_model(data)
            val_ce = ce_loss_fn(pred, targets)
            val_total = (1-beta) * val_ce + beta * RenyiDivergence(alpha=alpha)(pred/temperature, teacher_pred/temperature) * (temperature**2) / alpha
            # `.item()` keeps the running sums Python floats instead of device tensors.
            val_loss += val_total.item() * targets.shape[0]
            val_CE_loss += val_ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size

    print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    student_model.is_training = True

# Evaluation on the complete test set (all ten digits), cross-entropy only.
student_model.is_training = False
test_loss, test_accuracy, size = 0.0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = student_model(data)
        test_loss += ce_loss_fn(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These numbers come from test_loader, so they are reported as test metrics.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

# Accuracy restricted to the class that was withheld during training.
# From here on `test_accuracy` holds a raw count of correct 3s, not a fraction.
test_accuracy, size = 0, 0
with torch.no_grad():
    for data, targets in test_loader:
        mask = targets == 3
        data, targets = data[mask], targets[mask]
        data, targets = data.to(device), targets.to(device)
        pred = student_model(data)
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

print(f'Test Accuracy on 3s only: {100*test_accuracy/size:.2f}%')

Starting training...
---Epoch 1---
Average Train Loss: 0.7010 	 	 Average Train CE Loss: 0.3275 	 	 Train Accuracy: 90.54%
Average Validation Loss: 0.9244 	 Average Validation CE Loss: 0.5447 	 Validation Accuracy: 85.97%
---Epoch 2---
Average Train Loss: 0.2754 	 	 Average Train CE Loss: 0.1457 	 	 Train Accuracy: 95.71%
Average Validation Loss: 0.7525 	 Average Validation CE Loss: 0.4755 	 Validation Accuracy: 87.63%
---Epoch 3---
Average Train Loss: 0.2106 	 	 Average Train CE Loss: 0.1113 	 	 Train Accuracy: 96.68%
Average Validation Loss: 0.6263 	 Average Validation CE Loss: 0.3947 	 Validation Accuracy: 88.63%
---Epoch 4---
Average Train Loss: 0.1778 	 	 Average Train CE Loss: 0.0943 	 	 Train Accuracy: 97.20%
Average Validation Loss: 0.6247 	 Average Validation CE Loss: 0.3845 	 Validation Accuracy: 88.77%
---Epoch 5---
Average Train Loss: 0.1600 	 	 Average Train CE Loss: 0.0850 	 	 Train Accuracy: 97.37%
Average Validation Loss: 0.5398 	 Average Validation CE Loss: 0.3303 	 Va