In [0]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn as nn
import torch.optim as optim
# Hugging Face environment configuration.
# These variables must be in place BEFORE `datasets` is imported: the library
# resolves its cache directory from HF_DATASETS_CACHE when the package is first
# imported, so assigning it afterwards has no effect on the running kernel.
os.environ["HF_DATASETS_CACHE"] = "/dbfs/hf_datasets"
# Ask the `tokenizers` library not to use its own parallelism
# (presumably to avoid clashes with forked data-loader workers — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from datasets.utils.logging import disable_progress_bar
# Silence the per-dataset progress bars so cell output stays readable.
disable_progress_bar()

from trainer import Trainer, TrainingArguments
from bacp import BaCPTrainer, BaCPTrainingArguments
from utils import *
from constants import *

# Resolve the compute device once for the whole notebook; later cells move
# tensors and models onto it. `get_device` comes from one of the star imports
# above (presumably `utils` — confirm).
device = get_device()
# Self-documenting f-string: prints "device = <repr of device>".
print(f"{device = }")

In [0]:
# Notebook-level configuration: edit these to point the run at a different
# model/dataset pair or to switch training on.
MODEL_NAME: str = "vgg11"     # used in checkpoint paths and plot titles
MODEL_TASK: str = "cifar10"   # dataset/task identifier, also part of checkpoint paths
TRAIN: bool = False           # when False, the train() calls below are skipped

In [0]:
# ---------------------------------------------------------------------------
# Phase 1: BaCP pruning
# ---------------------------------------------------------------------------
# Baseline checkpoint passed to the pruning run as `finetuned_weights`.
finetuned_weights = f"/dbfs/research/{MODEL_NAME}/{MODEL_TASK}/{MODEL_NAME}_{MODEL_TASK}_baseline.pt"

# BATCH_SIZE and TARGET_SPARSITY_LOW are not defined in this notebook; they
# presumably come from `from constants import *` — confirm.
bacp_config = dict(
    model_name=MODEL_NAME,
    model_task=MODEL_TASK,
    batch_size=BATCH_SIZE,
    optimizer_type='sgd',
    learning_rate=0.1,
    pruning_type='magnitude_pruning',
    target_sparsity=TARGET_SPARSITY_LOW,
    sparsity_scheduler='cubic',
    finetuned_weights=finetuned_weights,
    learning_type='bacp_pruning',
)
bacp_training_args = BaCPTrainingArguments(**bacp_config)
bacp_trainer = BaCPTrainer(bacp_training_args=bacp_training_args)

# The trainer is always constructed; the actual pruning run is opt-in via TRAIN.
if TRAIN:
    bacp_trainer.train()

# ---------------------------------------------------------------------------
# Phase 2: finetuning
# ---------------------------------------------------------------------------
# The mask must be generated before the finetune arguments are assembled,
# because `get_pruner()` is evaluated while building them.
bacp_trainer.generate_mask_from_model()

# Most settings are carried over from the pruning trainer so both phases agree;
# only the optimizer, learning rate, epoch count and learning type differ.
finetune_config = dict(
    model_name=bacp_trainer.model_name,
    model_task=bacp_trainer.model_task,
    batch_size=bacp_trainer.batch_size,
    optimizer_type='adamw',
    learning_rate=0.0001,
    pruner=bacp_trainer.get_pruner(),
    pruning_type=bacp_trainer.pruning_type,
    target_sparsity=bacp_trainer.target_sparsity,
    epochs=50,
    finetuned_weights=bacp_trainer.save_path,
    finetune=True,
    learning_type="bacp_finetune",
)
training_args = TrainingArguments(**finetune_config)
trainer = Trainer(training_args)
if TRAIN:
    trainer.train()

# Evaluation runs regardless of TRAIN.
metrics = trainer.evaluate()
print(f"\n{metrics}")

In [0]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from models import *
from dataset_utils import *

# --- Grab one batch of contrastive training data ---------------------------
# NOTE(review): a batch size of 1028 looks like it may be a typo for 1024 — confirm.
data = get_cv_data(MODEL_TASK, 1028, learning_type='contrastive')
trainloader = data['trainloader']
data_batch = next(iter(trainloader))

# Each batch unpacks into (samples, labels), and `samples` itself unpacks into
# two tensors (presumably two augmented views of the same images — confirm).
samples, labels = data_batch
samples_1, samples_2 = samples
samples_1 = samples_1.to(device)
# samples_2 is moved to the device but is not used by any cell shown here.
samples_2 = samples_2.to(device)

# --- Load the pruned encoder ------------------------------------------------
# Checkpoint path is hard-coded to the magnitude-pruning / 0.99-sparsity run.
back_weights = f'/dbfs/research/{MODEL_NAME}/{MODEL_TASK}/{MODEL_NAME}_{MODEL_TASK}_magnitude_pruning_0.99_bacp_pruning.pt'
encoder_model = EncoderProjectionNetwork(MODEL_NAME).to(device)
load_weights(encoder_model, back_weights)
# Re-assert the device after loading, in case load_weights changed it
# (its device handling is not visible from this notebook).
encoder_model.to(device)
encoder_model.eval()

# --- Embed the first view ---------------------------------------------------
# no_grad() already prevents graph construction, so no extra detach is needed
# before converting to NumPy.
with torch.no_grad():
    embeddings = encoder_model(samples_1)
embeddings = embeddings.cpu().numpy()

# --- Project to 2-D with t-SNE ----------------------------------------------
# The iteration count is intentionally left at the library default (1000, the
# value previously passed explicitly): scikit-learn renamed `n_iter` to
# `max_iter` in 1.5, so spelling out either keyword fails on some versions.
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
embeddings_2d = tsne.fit_transform(embeddings)

In [0]:
# Scatter plot of the 2-D t-SNE embedding, coloured by class label.
fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(
    embeddings_2d[:, 0], embeddings_2d[:, 1],
    c=labels, cmap='tab10', alpha=0.7
)
fig.colorbar(scatter, ax=ax, label='Class Label')
ax.set_title(f't-SNE Visualization ({MODEL_NAME} - {MODEL_TASK})')
ax.set_xlabel('t-SNE 1')
ax.set_ylabel('t-SNE 2')
ax.grid(True)

# Save BEFORE showing: once plt.show() has run, the current figure is finished,
# and a later plt.savefig() writes out a new, empty canvas instead of this plot.
# The file name follows MODEL_NAME / MODEL_TASK; the "mag99" suffix mirrors the
# magnitude-pruning 0.99 checkpoint that produced the embeddings.
fig.savefig(f'tsne_plot_{MODEL_NAME}_{MODEL_TASK}_mag99.png', dpi=300)
plt.show()



In [0]:
# Seaborn rendering of the same 2-D embedding, with a per-class legend
# instead of a colour bar.
sns.set(style='whitegrid')  # apply the theme before the figure is created

tsne_fig, tsne_ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    x=embeddings_2d[:, 0],
    y=embeddings_2d[:, 1],
    hue=labels,
    palette='tab10',
    s=60,
    edgecolor='k',
    alpha=0.8,
    ax=tsne_ax,
)
tsne_ax.set_title('t-SNE Visualization of Embeddings')
tsne_ax.set_xlabel('t-SNE 1')
tsne_ax.set_ylabel('t-SNE 2')
tsne_ax.legend(title='Class')
plt.show()
