In [None]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb

In [None]:
# Example value tensors: two rows of three entries each.
# Only x tracks gradients.
x = torch.tensor(
    [
        [9/25, 12/25, 4/25],
        [2/7, 1/4, 1/5],
    ],
    requires_grad=True,
)
y = torch.tensor(
    [
        [1/3, 1/3, 1/3],
        [1/2, 5/4, 1/5],
    ]
)
print(x, y, sep="\n")

In [None]:
# Example of divergence being infinity
#x = torch.tensor([[-0.7753, -0.7236, -0.6072, -0.8284, -0.8085, -0.8923, -0.6853, -0.8315,-0.8302, -0.7934]])
#y = torch.tensor([[-9.7151e+00, -7.6585e+00, -7.3228e+00, -7.9473e-08, -1.0569e+01,-6.7095e+00, -1.1969e+01, -7.1102e+00, -5.1660e+00, -6.2902e+00]])
#print(x)
#print(y)

In [None]:
# Definition of Renyi Divergence with logits as input
class RenyiDivergence(nn.Module):
    """Renyi divergence D_alpha(P || Q) between categorical distributions given as logits.

    With P = softmax(target_logits) and Q = softmax(input_logits), both taken
    along dim 1, the forward pass returns the batch mean of

        D_alpha(P || Q) = 1 / (alpha - 1) * log(sum_i P_i^alpha * Q_i^(1 - alpha))

    and, for alpha == 1, of the Kullback-Leibler divergence
    KL(P || Q) = sum_i P_i * (log P_i - log Q_i).

    Everything is evaluated in log-space (log_softmax / logsumexp), so large
    logits do not overflow the way a direct torch.exp(logits) would.

    Args:
        alpha: order of the divergence (default 0.5).
    """

    def __init__(self, alpha=0.5):
        super(RenyiDivergence, self).__init__()
        self.alpha = alpha

    def forward(self, input_logits, target_logits):
        """Return the mean divergence over the batch.

        Args:
            input_logits: unnormalised log-probabilities of Q, classes along dim 1.
            target_logits: unnormalised log-probabilities of P, same shape.

        Returns:
            0-dim tensor: mean over the remaining dimensions of the per-sample divergence.
        """
        log_q = F.log_softmax(input_logits, dim=1)
        log_p = F.log_softmax(target_logits, dim=1)

        if self.alpha == 1.0:  # KL Divergence
            # F.kl_div expects log Q as `input` and P (not log P) as `target`;
            # entries with P_i == 0 contribute 0 instead of 0 * (-inf) = nan.
            loss = F.kl_div(log_q, log_p.exp(), reduction='none').sum(dim=1)
            return loss.mean()

        # log(sum_i P_i^alpha * Q_i^(1-alpha)) computed as a logsumexp of the
        # weighted log-probabilities.
        weighted = self.alpha * log_p + (1 - self.alpha) * log_q
        loss = torch.logsumexp(weighted, dim=1) / (self.alpha - 1)

        return loss.mean()

In [None]:
# Smoke test: evaluate the alpha = 1 divergence on a small batch of two
# identical rows and print the value together with the gradient w.r.t. x.
x = torch.tensor([[2/3, 1/3, 1/3]] * 2, requires_grad=True)
y = torch.tensor([[1/3, 2/3, 1/3]] * 2)

kl_criterion = RenyiDivergence(alpha=1)
output = kl_criterion(x, y)
print(output)

output.backward()
print(x.grad)

In [None]:
## Preparation for plot creation
def _divergence_curve(alpha, vary_target, fixed_probs=(1/3, 2/3), steps=500):
    """Evaluate D_alpha on a grid of two-outcome distributions.

    One distribution is held at `fixed_probs`; the other sweeps over
    (k/steps, 1 - k/steps) for k = 0..steps. Distributions are passed to
    RenyiDivergence as log-probabilities.

    Args:
        alpha: order of the divergence.
        vary_target: if True the swept distribution is the target (P) and the
            fixed one is the input (Q); if False the roles are swapped.
        fixed_probs: probabilities of the distribution that stays fixed.
        steps: number of grid intervals (the curve has steps + 1 points).

    Returns:
        List of Python floats, one per grid point (may contain inf/nan at the
        grid end points where a probability is exactly 0).
    """
    criterion = RenyiDivergence(alpha=alpha)
    fixed = torch.log(torch.tensor([fixed_probs]))
    curve = []
    for k in range(steps + 1):
        moving = torch.log(torch.tensor([[k/steps, 1 - k/steps]]))
        # RenyiDivergence.forward takes (input_logits=Q, target_logits=P).
        value = criterion(fixed, moving) if vary_target else criterion(moving, fixed)
        # Store plain floats so the curves convert cleanly to numpy arrays.
        curve.append(value.item())
    return curve


# Divergence as a function of p with Q = (1/3, 2/3) fixed ...
lists = [_divergence_curve(order, vary_target=True) for order in [0.5, 1, 2, 10]]

# ... and as a function of q with P = (1/3, 2/3) fixed.
lists_2 = [_divergence_curve(order, vary_target=False) for order in [0.5, 1, 2, 10]]

In [None]:
## Plotting
import matplotlib.pyplot as plt
import numpy as np

# One entry per curve in `lists`, in order: (colour, line style, legend label).
p_curve_styles = (
    ('blue', '-', r'$\alpha = 1/2$'),
    ('red', '--', r'$\alpha = 1$'),
    ('orange', ':', r'$\alpha = 2$'),
    ('green', '-.', r'$\alpha = 10$'),
)
p_grid = np.array(range(501)) / 500

fig, ax = plt.subplots()

for idx, (colour, dash, legend_text) in enumerate(p_curve_styles):
    ax.plot(p_grid, np.transpose(np.array(lists[idx])),
            color=colour, linestyle=dash, label=legend_text)

# Hide the top and right frame lines.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Axis captions are drawn as annotations offset from fixed data coordinates.
caption_style = dict(textcoords='offset points', ha='center', va='center', fontsize=12)
ax.annotate(r'$p$', xy=(1.05, 0), xytext=(10, -20), **caption_style)
ax.annotate(r'$D_\alpha(P \| Q)$', xy=(0, 1.05), xytext=(-40, 10), **caption_style)

ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')

ax.legend()

fig.savefig('Renyi_Divergence.png')

plt.show()


In [None]:
# Same figure layout as the previous plot, but for the curves in `lists_2`
# (divergence as a function of q).
q_curve_styles = (
    ('blue', '-', r'$\alpha = 1/2$'),
    ('red', '--', r'$\alpha = 1$'),
    ('orange', ':', r'$\alpha = 2$'),
    ('green', '-.', r'$\alpha = 10$'),
)
q_grid = np.array(range(501)) / 500

fig, ax = plt.subplots()

for idx, (colour, dash, legend_text) in enumerate(q_curve_styles):
    ax.plot(q_grid, np.transpose(np.array(lists_2[idx])),
            color=colour, linestyle=dash, label=legend_text)

# Hide the top and right frame lines.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Axis captions are drawn as annotations offset from fixed data coordinates.
caption_style = dict(textcoords='offset points', ha='center', va='center', fontsize=12)
ax.annotate(r'$q$', xy=(1.05, 0), xytext=(10, -20), **caption_style)
ax.annotate(r'$D_\alpha(P \| Q)$', xy=(0, 5.45), xytext=(-40, 10), **caption_style)

ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')

ax.legend()

fig.savefig('Renyi_Divergence_2.png')

plt.show()

In [None]:
# Preprocessing shared by every split: convert to tensor, then normalise with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)


def _seeded_generator():
    """Return a fresh torch.Generator seeded with 872."""
    return torch.Generator().manual_seed(872)


# MNIST train/test sets (downloaded into ./data if missing).
mnist_options = dict(root='./data', download=True, transform=transform)
full_train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

# Hold out 5% of the training images for validation.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [0.95, 0.05], generator=_seeded_generator()
)

# Batched loaders; train and validation are shuffled with their own seeded generator.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=_seeded_generator())
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True, generator=_seeded_generator())
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [None]:
class TeacherNetwork(nn.Module):
    """Fully connected classifier (784 -> 1024 -> 256 -> 10) with dropout.

    Dropout is applied to the flattened input (p = dropout_input) and after
    each hidden ReLU (p = dropout_hidden). It follows the standard nn.Module
    mode, so both ``model.eval()`` / ``model.train()`` and the
    ``is_training`` attribute switch it off and on.
    """

    def __init__(self):
        super(TeacherNetwork, self).__init__()
        self.fc1 = nn.Linear(28*28, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
        self.dropout_input = 0.1
        self.dropout_hidden = 0.2
        self.is_training = True

    @property
    def is_training(self):
        """Alias of ``self.training``; kept so existing code can read and set it."""
        return self.training

    @is_training.setter
    def is_training(self, mode):
        # Route through nn.Module.train so the flag and the module mode can never disagree.
        self.train(bool(mode))

    def forward(self, x):
        """Return the 10 class logits for a batch of 28x28 images."""
        x = x.view(-1, 28*28)  # Flatten the image
        x = F.dropout(x, p=self.dropout_input, training=self.training)
        x = F.dropout(F.relu(self.fc1(x)), p=self.dropout_hidden, training=self.training)
        x = F.dropout(F.relu(self.fc2(x)), p=self.dropout_hidden, training=self.training)
        x = self.fc3(x)
        return x

In [None]:
# Train the teacher with cross-entropy, validate after every epoch and
# evaluate once on the test set when training is done.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model = TeacherNetwork().to(device)
optimizer = optim.SGD(teacher_model.parameters(), lr=0.05, momentum=0.5, weight_decay=3e-4)
num_epochs = 10
criterion = nn.CrossEntropyLoss()  # built once instead of once per batch

wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Teacher Model",
    config={}
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # --- training pass (dropout active) ---
    teacher_model.is_training = True
    train_loss, train_accuracy, size = 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = teacher_model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Weight the batch mean by the batch size so the epoch average is per sample.
        train_loss += loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # --- validation pass (dropout off, no gradients) ---
    teacher_model.is_training = False

    val_loss, val_accuracy, size = 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = teacher_model(data)
            # .item() keeps the running sum a Python float (not a device tensor).
            val_loss += criterion(pred, targets).item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy = val_loss / size, val_accuracy / size

    print(f'Average Validation Loss: {val_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_ce_loss": train_loss,
               "val_ce_loss": val_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

# --- final evaluation on the test set ---
teacher_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = teacher_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; they were previously printed under a "Validation" label.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

In [None]:
class StudentNetwork(nn.Module):
    """Small fully connected classifier (784 -> 64 -> 10) with dropout.

    Dropout is applied to the flattened input (p = dropout_input) and after
    the hidden ReLU (p = dropout_hidden). It follows the standard nn.Module
    mode, so both ``model.eval()`` / ``model.train()`` and the
    ``is_training`` attribute switch it off and on.
    """

    def __init__(self):
        super(StudentNetwork, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 10)
        self.dropout_input = 0.1
        self.dropout_hidden = 0.2
        self.is_training = True

    @property
    def is_training(self):
        """Alias of ``self.training``; kept so existing code can read and set it."""
        return self.training

    @is_training.setter
    def is_training(self, mode):
        # Route through nn.Module.train so the flag and the module mode can never disagree.
        self.train(bool(mode))

    def forward(self, x):
        """Return the 10 class logits for a batch of 28x28 images."""
        x = x.view(-1, 28 * 28)  # Flatten the image
        x = F.dropout(x, p=self.dropout_input, training=self.training)
        x = F.dropout(F.relu(self.fc1(x)), p=self.dropout_hidden, training=self.training)
        x = self.fc2(x)
        return x


In [None]:
# Baseline: train the student architecture with cross-entropy only (no teacher),
# validate after every epoch and evaluate once on the test set at the end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vanilla_model = StudentNetwork().to(device)
optimizer = optim.SGD(vanilla_model.parameters(), lr=0.01, momentum=0.5, weight_decay=5e-4)
num_epochs = 10
criterion = nn.CrossEntropyLoss()  # built once instead of once per batch

wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Vanilla Model",
    config={}
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # --- training pass (dropout active) ---
    vanilla_model.is_training = True
    train_loss, train_accuracy, size = 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = vanilla_model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Weight the batch mean by the batch size so the epoch average is per sample.
        train_loss += loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # --- validation pass (dropout off, no gradients) ---
    vanilla_model.is_training = False

    val_loss, val_accuracy, size = 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = vanilla_model(data)
            # .item() keeps the running sum a Python float (not a device tensor).
            val_loss += criterion(pred, targets).item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy = val_loss / size, val_accuracy / size

    print(f'Average Validation Loss: {val_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_ce_loss": train_loss,
               "val_ce_loss": val_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

# --- final evaluation on the test set ---
vanilla_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = vanilla_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; they were previously printed under a "Validation" label.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

In [None]:
# Knowledge distillation: train the student on a blend of cross-entropy and the
# temperature-scaled Renyi divergence to the teacher's outputs:
#   loss = (1 - beta) * CE + beta * D_alpha(teacher/T || student/T) * T^2 / alpha
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = StudentNetwork().to(device)
optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.5, weight_decay=5e-4)
num_epochs = 10
temperature = 3
beta = 0.85
alpha = 1
criterion = nn.CrossEntropyLoss()          # built once instead of once per batch
divergence = RenyiDivergence(alpha=alpha)  # input = student logits, target = teacher logits

# The teacher only provides targets here, so keep its dropout switched off.
teacher_model.is_training = False

wandb.init(
    project="Renyi_Divergence_MNIST",
    name = "Student Model",
    config={
        "beta": beta,
        "temperature": temperature,
        "alpha": alpha
    }
)

print("Starting training...")

for epoch in range(num_epochs):
    print(f'---Epoch {epoch+1}---')

    # --- training pass (student dropout active) ---
    student_model.is_training = True
    train_loss, train_accuracy, size, train_CE_loss = 0, 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = student_model(data)
        with torch.no_grad():
            teacher_outputs = teacher_model(data)
        # Cross-entropy is computed once and reused for both the loss and the logged metric.
        CE_loss = criterion(outputs, targets)
        loss = (1-beta) * CE_loss + beta * divergence(outputs/temperature, teacher_outputs/temperature) * (temperature**2) / alpha
        loss.backward()

        optimizer.step()
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_CE_loss = train_CE_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # --- validation pass (dropout off, no gradients) ---
    student_model.is_training = False

    val_loss, val_accuracy, size, val_CE_loss = 0, 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = student_model(data)
            teacher_pred = teacher_model(data)
            ce = criterion(pred, targets)
            total = (1-beta) * ce + beta * divergence(pred/temperature, teacher_pred/temperature) * (temperature**2) / alpha
            # .item() keeps the running sums Python floats (not device tensors).
            val_loss += total.item() * targets.shape[0]
            val_CE_loss += ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size

    print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_loss": train_loss,
               "train_ce_loss": train_CE_loss,
               "val_loss": val_loss,
               "val_ce_loss": val_CE_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

# --- final evaluation on the test set (cross-entropy only) ---
student_model.is_training = False
test_loss, test_accuracy, size = 0, 0, 0

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        pred = student_model(data)
        test_loss += criterion(pred, targets).item() * targets.shape[0]
        test_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
        size += targets.shape[0]

test_loss, test_accuracy = test_loss / size, test_accuracy / size

print('Training finished.')
# These are test-set metrics; they were previously printed under a "Validation" label.
print(f'Average Test Loss: {test_loss:.4f}\t Test Accuracy: {100*test_accuracy:.2f}%')

wandb.finish()

In [None]:
# Re-create an untrained student with its optimiser and reset the distillation
# hyperparameters.
# NOTE(review): this rebinds `student_model`, so the logit/softmax plots further
# down use this freshly initialised network when the notebook is run top to
# bottom — confirm that this is intended.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
student_model = StudentNetwork().to(device)
sgd_settings = dict(lr=0.01, momentum=0.5, weight_decay=5e-4)
optimizer = optim.SGD(student_model.parameters(), **sgd_settings)
num_epochs, temperature, beta, alpha = 10, 3, 0.85, 1

In [None]:
def train(config=None,lr=None,momentum=None,weight_decay=None,temperature=None,beta=None,alpha=None):
    """Run one wandb-tracked distillation training of a fresh StudentNetwork.

    A wandb run is opened with `config`; every keyword argument that is not
    None is then written into `wandb.config`. The student is trained for 10
    epochs with SGD against the module-level `teacher_model`, using
    `build_dataset()` for the loaders and `train_epoch()` for each epoch.

    Args:
        config: initial configuration handed to `wandb.init`.
        lr, momentum, weight_decay: SGD settings.
        temperature, beta, alpha: distillation settings read by `train_epoch`.
    """
    # Collected up front; only the entries that were actually given are applied.
    overrides = {
        'lr': lr,
        'momentum': momentum,
        'weight_decay': weight_decay,
        'temperature': temperature,
        'beta': beta,
        'alpha': alpha,
    }

    with wandb.init(config=config):
        config = wandb.config
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Alternative setup (teacher, no distillation):
        #network = TeacherNetwork().to(device)
        #destilation, teacher = False, None
        network = StudentNetwork().to(device)
        destilation, teacher = True, teacher_model  # teacher_model is the notebook-level teacher
        epochs = 10

        train_loader, val_loader, test_loader = build_dataset()

        for key, value in overrides.items():
            if value is not None:
                setattr(config, key, value)

        optimizer = optim.SGD(
            network.parameters(),
            lr=config.lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
        )

        epoch = 0
        while epoch < epochs:
            train_epoch(epoch, network, train_loader, val_loader, optimizer, device, destilation, teacher, config)
            epoch += 1
            

In [None]:
def build_dataset():
    """Build the MNIST train, validation and test loaders.

    The training set is split 95% / 5% into train and validation subsets with
    a generator seeded with 872. Train and validation loaders use batch size
    64 and seeded shuffling; the test loader uses batch size 1000 without
    shuffling.

    Returns:
        Tuple (train_loader, val_loader, test_loader).
    """
    def seeded():
        # A fresh generator per consumer, each starting from seed 872.
        return torch.Generator().manual_seed(872)

    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    shared = dict(root='./data', download=True, transform=preprocessing)

    full_train = datasets.MNIST(train=True, **shared)
    test_split = datasets.MNIST(train=False, **shared)
    train_split, val_split = torch.utils.data.random_split(
        full_train, [0.95, 0.05], generator=seeded()
    )

    return (
        DataLoader(train_split, batch_size=64, shuffle=True, generator=seeded()),
        DataLoader(val_split, batch_size=64, shuffle=True, generator=seeded()),
        DataLoader(test_split, batch_size=1000, shuffle=False),
    )
    
def _distillation_loss(ce_loss, student_logits, teacher_logits, config):
    """Blend cross-entropy with the temperature-scaled Renyi divergence to the teacher."""
    divergence = RenyiDivergence(alpha=config.alpha)
    return (1-config.beta) * ce_loss + config.beta * divergence(student_logits/config.temperature, teacher_logits/config.temperature) * (config.temperature**2) / config.alpha


def train_epoch(epoch, network, train_loader, val_loader, optimizer, device, destilation, teacher_model, config):
    """Train `network` for one epoch, validate it and log the metrics to wandb.

    Args:
        epoch: epoch index, logged as "epoch".
        network: model being trained; its `is_training` flag is set to False
            for validation and back to True before returning.
        train_loader, val_loader: iterables of (data, targets) batches.
        optimizer: optimizer stepping the parameters of `network`.
        device: device the batches are moved to.
        destilation: if True the loss is
            (1 - beta) * CE + beta * D_alpha(teacher/T || student/T) * T^2 / alpha
            with beta, temperature and alpha read from `config`; if False the
            loss is plain cross-entropy and `teacher_model` is never called
            (it may be None).
        teacher_model: model providing the target logits for distillation.
        config: object with attributes `beta`, `temperature`, `alpha`
            (only read when `destilation` is True).
    """
    criterion = nn.CrossEntropyLoss()

    # --- training pass ---
    train_loss, train_accuracy, size, train_CE_loss = 0, 0, 0, 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = network(data)
        # Cross-entropy is needed for the metric in both modes, so compute it once.
        CE_loss = criterion(outputs, targets)
        if destilation:
            with torch.no_grad():
                teacher_outputs = teacher_model(data)
            loss = _distillation_loss(CE_loss, outputs, teacher_outputs, config)
        else:
            loss = CE_loss
        loss.backward()
        optimizer.step()
        # Weight batch means by the batch size so epoch averages are per sample.
        train_loss += loss.item() * targets.size(0)
        train_CE_loss += CE_loss.item() * targets.size(0)
        train_accuracy += torch.sum(torch.argmax(outputs, dim=1) == targets).item()
        size += targets.size(0)

    train_loss = train_loss/size
    train_CE_loss = train_CE_loss/size
    train_accuracy = train_accuracy/size

    print(f'Average Train Loss: {train_loss:.4f} \t \t Average Train CE Loss: {train_CE_loss:.4f} \t \t Train Accuracy: {100*train_accuracy:.2f}%')

    # --- validation pass (dropout off, no gradients) ---
    network.is_training = False

    val_loss, val_accuracy, size, val_CE_loss = 0, 0, 0, 0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            pred = network(data)
            ce = criterion(pred, targets)
            # Mirror the training loss: only consult the teacher when distilling.
            if destilation:
                total = _distillation_loss(ce, pred, teacher_model(data), config)
            else:
                total = ce
            # .item() keeps the running sums Python floats (not device tensors).
            val_loss += total.item() * targets.shape[0]
            val_CE_loss += ce.item() * targets.shape[0]
            val_accuracy += torch.sum(torch.argmax(pred, dim=1) == targets).item()
            size += targets.shape[0]

    val_loss, val_accuracy, val_CE_loss = val_loss / size, val_accuracy / size, val_CE_loss/size

    print(f'Average Validation Loss: {val_loss:.4f} \t Average Validation CE Loss: {val_CE_loss:.4f} \t Validation Accuracy: {100*val_accuracy:.2f}%')

    wandb.log({"train_loss": train_loss,
               "train_ce_loss": train_CE_loss,
               "val_loss": val_loss,
               "val_ce_loss": val_CE_loss,
               "train_accuracy": 100*train_accuracy,
               "val_accuracy": 100*val_accuracy,
               "epoch": epoch
               }
    )

    network.is_training = True

In [None]:
# Bayesian sweep maximising validation accuracy over beta (uniform in [0.5, 1])
# and temperature (integer-uniform in [1, 25]).
sweep_config = dict(
    method='bayes',
    metric=dict(name='val_accuracy', goal='maximize'),
    parameters=dict(
        beta=dict(distribution='uniform', min=0.5, max=1),
        temperature=dict(distribution='int_uniform', min=1, max=25),
    ),
)

In [None]:
# Register the sweep described by sweep_config under the project
# "Renyi_Divergence_Sweep_Student"; the result is passed to wandb.agent below.
sweep_id = wandb.sweep(sweep_config, project="Renyi_Divergence_Sweep_Student")

In [None]:
# Launch the sweep agent with count=30. Each call to train() fixes lr, momentum,
# weight_decay and alpha, so only the swept parameters in sweep_config
# (beta, temperature) are left to the sweep.
wandb.agent(sweep_id, function=lambda: train(lr=0.01,momentum=0.5,weight_decay=5e-4,alpha=1), count=30)

In [None]:
# Finish the wandb run after the sweep — presumably a safeguard in case a run is still open.
wandb.finish()

In [None]:
# Compute student and teacher logits for the first image of one training batch.
# Both models are put in inference mode first. The `is_training` flag is set as
# well as calling eval(), because the networks in this notebook gate dropout on
# that flag, so eval() on its own would leave dropout active.
for model in (student_model, teacher_model):
    model.eval()
    model.is_training = False

data, target = next(iter(train_loader))
with torch.no_grad():
    # data[0] is a single image; the networks flatten it to a batch of one.
    output = student_model(data[0].to(device))
    teacher_model_output = teacher_model(data[0].to(device))

In [None]:
# Grouped bar chart of the raw logits per class: student bars at the integer
# positions, teacher bars shifted right by one bar width.
bar_width = 0.35
class_positions = np.arange(10)
student_logits = np.squeeze(output.cpu().numpy())
teacher_logits = np.squeeze(teacher_model_output.cpu().numpy())

plt.bar(class_positions, student_logits, width=bar_width, color='blue', label='Student Model Logits')
plt.bar(class_positions + bar_width, teacher_logits, width=bar_width, color='green', label='Teacher Model Logits')
plt.legend()
plt.title("Logits")
plt.show()

In [None]:
# Grouped bar chart of the class probabilities (softmax of the logits, no
# temperature). The legend and title now say "Probabilities"/"Softmax"; they
# previously read "Logits" and "Sotmax".
student_probs = np.squeeze(torch.softmax(output, dim=1).cpu().numpy())
teacher_probs = np.squeeze(torch.softmax(teacher_model_output, dim=1).cpu().numpy())

plt.bar(np.arange(10), student_probs, width=0.35, color='blue', label='Student Model Probabilities')
plt.bar(np.arange(10)+0.35, teacher_probs, width=0.35, color='green', label='Teacher Model Probabilities')
plt.legend()
plt.title("Softmax without temperature")
plt.show()

In [None]:
# Same comparison with the logits divided by a temperature before the softmax.
# The legend and title now say "Probabilities"/"Softmax"; they previously read
# "Logits" and "Sotmax".
temperature = 5

student_probs_t = np.squeeze(torch.softmax(output/temperature, dim=1).cpu().numpy())
teacher_probs_t = np.squeeze(torch.softmax(teacher_model_output/temperature, dim=1).cpu().numpy())

plt.bar(np.arange(10), student_probs_t, width=0.35, color='blue', label='Student Model Probabilities')
plt.bar(np.arange(10)+0.35, teacher_probs_t, width=0.35, color='green', label='Teacher Model Probabilities')
plt.legend()
plt.title("Softmax with temperature")
plt.show()