# GPU Setting

In [None]:
import tensorflow as tf

# Toggle for GPU use; flip to False to force the CPU label even if a GPU is visible.
USE_GPU = True
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# `device` becomes 'GPU' only when GPU use is requested AND TensorFlow reports
# at least one physical GPU; in every other case it is 'CPU'.
device = 'GPU' if (USE_GPU and tf.config.list_physical_devices('GPU')) else 'CPU'
print("Using GPU" if device == 'GPU' else "Using CPU")

# Loading Dataset

In [None]:
import tarfile
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import random

import os

# Path to the downloaded tgz file. The default is the original local location;
# set the IMAGENETTE_TGZ environment variable to use a different archive
# without editing the notebook.
tgz_path = os.environ.get(
    "IMAGENETTE_TGZ",
    "/home/asko/Documents/workspace/Fall-24/682/project/dataset/imagenette2.tgz",
)
extract_path = "./imagenette"  # Target folder for extraction
extract_path_train = "./imagenette/imagenette2/train"
extract_path_val = "./imagenette/imagenette2/val"

# Extract the archive only when the expected folders are missing, so re-running
# the cell does not unpack the whole dataset again. (Delete ./imagenette to
# force a fresh extraction, e.g. after an interrupted run.)
if os.path.isdir(extract_path_train) and os.path.isdir(extract_path_val):
    print("Dataset already extracted; skipping extraction.")
else:
    with tarfile.open(tgz_path, "r:gz") as tar:
        tar.extractall(path=extract_path)
    print("Extraction completed.")


transform = transforms.Compose([
    transforms.Resize((224,224)),  # Resize images to a size suitable for VGG16
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize as per VGG16
])
train_dataset = datasets.ImageFolder(
    root=extract_path_train,  # Local folder holding the extracted training images
    transform=transform
)
val_dataset = datasets.ImageFolder(
    root=extract_path_val,  # Local folder holding the extracted validation images
    transform=transform
)

# The extracted "val" folder is split into a validation part (used during
# training) and a held-out test part.
VAL_SPLIT_SIZE = 1000  # number of images assigned to the validation part
SPLIT_SEED = 0  # fixed so the validation/test split is identical on every run

# ImageFolder orders samples class by class, so taking the first
# VAL_SPLIT_SIZE positions would yield a subset covering only the first few
# classes. Shuffle the indices once (seeded) before slicing instead.
val_indices = list(range(len(val_dataset)))
random.Random(SPLIT_SEED).shuffle(val_indices)

batch_size = 16
sampler_train = sampler.SubsetRandomSampler(range(len(train_dataset)))
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_train)

sampler_val = sampler.SubsetRandomSampler(val_indices[:VAL_SPLIT_SIZE])
val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler_val)

batch_size = 32
# SubsetRandomSampler takes (indices, generator); the subset size is implied by
# the indices, so no second positional argument is passed.
sampler_test = sampler.SubsetRandomSampler(val_indices[VAL_SPLIT_SIZE:])
test_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler_test)


In [None]:
import torch

# Limit this process to 95% of the memory of CUDA device 0. The call needs a
# CUDA device, so it is skipped on CPU-only machines to keep the notebook
# runnable there.
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.95, device=0)

In [None]:
import tensorflow as tf
from TACNNModel.TACNN import TACNN

import dill
import torch

# Hyperparameters forwarded to TACNN.
# NOTE(review): presumably the distillation loss weight and softmax
# temperature -- confirm against the TACNN implementation.
alpha = 0.03
temperature = 0.02

# Load the VGG16 Teacher model
# NOTE(review): dill.load can execute arbitrary code embedded in the file;
# only load a model.pkl that comes from a trusted source.
with open("model.pkl", "rb") as f:
    teacher_model = dill.load(f)

# Build the TA (student) model that will be trained
ta_model = TACNN(alpha=alpha, temperature=temperature, num_of_classes=10)

# Both models are PyTorch modules (moved with .to), so ask PyTorch -- not
# TensorFlow -- whether CUDA is usable; the two frameworks can disagree.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
ta_model.to(device)
print(device)



In [None]:
# Cache the teacher's outputs for one pass over the training loader.
# NOTE(review): train_loader draws its batches through a random sampler, so the
# batch order of this pass is not repeated on later passes; entries indexed by
# batch position only correspond to the images seen in this particular pass.
teacher_outputs = []

teacher_model.eval()  # the teacher is used for inference only
with torch.no_grad():  # cached targets need no autograd graph (saves memory)
    for images, _ in train_loader:
        # The teacher lives on `device`, so its inputs must be moved there too.
        images = images.to(device, dtype=torch.float32)
        teacher_outputs.append(teacher_model(images))

In [None]:
def check_accuracy_part34(loader, model, device=None):
    """Evaluate classification accuracy of `model` over `loader`.

    Puts the model in evaluation mode (it is NOT switched back to training
    mode afterwards), runs every batch without gradient tracking, prints a
    summary line and returns the accuracy.

    Args:
        loader: iterable yielding `(x, y)` batches of inputs and integer labels.
        model: module whose output has one score per class along dimension 1.
        device: device to run on. When omitted, the device of the model's
            first parameter is used (CPU if the model has no parameters), so
            the function does not depend on a notebook-level variable.

    Returns:
        Fraction of correctly classified samples as a float in [0, 1];
        0.0 when the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    print('Checking accuracy on validation set')
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)  # index of the highest score = predicted class
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)
    # Guard against an empty loader instead of dividing by zero.
    acc = num_correct / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc


In [None]:
import torch.optim as optim
import torch
import torch.nn.functional as F

# Distillation training of the TA model against the (frozen) teacher.
optimizer = optim.Adam(ta_model.parameters(), lr=0.01)
torch.set_grad_enabled(True)

num_epochs = 50
batch_size = 16  # not used in this cell
model_weights = None  # not used in this cell
print_every = 100  # report loss / validation accuracy every this many batches

teacher_model.eval()  # the teacher only provides targets; it is never trained

for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0
    for batch_idx, (images, labels) in enumerate(train_loader):
        # check_accuracy_part34 leaves the model in eval mode, so training
        # mode has to be re-enabled before every step.
        ta_model.train()

        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.long)

        # Compute the teacher's output on this exact batch. train_loader uses
        # a random sampler, so batches differ in every pass and outputs cached
        # by batch position would belong to different images.
        with torch.no_grad():
            teacher_output = teacher_model(images)

        ta_output = ta_model(images)

        loss = ta_model.risk(Y=labels, teacher_preds=teacher_output, output=ta_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        if batch_idx % print_every == 0:
            print('Iteration %d, loss = %.4f' % (batch_idx, loss.item()))
            check_accuracy_part34(val_loader, ta_model)
            print()
    print(f"Epoch {epoch} complete")
    if num_batches:
        print('Mean training loss = %.4f' % (total_loss / num_batches))