In [None]:
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, ConfusionMatrixDisplay
from torch.utils.data import Dataset, DataLoader

In [None]:
# Run configuration: everything a reader might want to tune lives here.
BATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 0.0001
EPOCHS = 100
EVALUATION_DELAY = 10  # validation runs on epochs divisible by this value
SEED = 42

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable from one "Restart & Run All" to the next.
torch.manual_seed(SEED)
print(DEVICE)

In [None]:
# Load the pickled training records.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open("../../data/custom_fragments_30/embeddings/final_embeddings_train.pkl", "rb") as train_file:
    train_embeddings = pickle.load(train_file)

In [None]:
# Load the pickled validation records.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open("../../data/custom_fragments_30/embeddings/final_embeddings_val.pkl", "rb") as val_file:
    val_embeddings = pickle.load(val_file)

In [None]:
# Inspect the first training record; RandomDataset reads index 1 as the label and index 3 as the embedding.
train_embeddings[0]

In [None]:
# Number of training records.
len(train_embeddings)

In [None]:
# Number of validation records.
len(val_embeddings)

In [None]:
class RandomDataset(Dataset):
    """In-memory dataset of (embedding, label) pairs.

    Each input record is indexed positionally: position 1 supplies the label,
    which is wrapped in a one-element float32 tensor, and position 3 supplies
    the embedding, which is stored unchanged.
    """

    def __init__(self, embeddings):
        # Materialise the records once so both lists are built from the same pass.
        records = list(embeddings)
        self.labels = [torch.tensor([record[1]], dtype=torch.float32) for record in records]
        self.embeddings = [record[3] for record in records]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

In [None]:

# Wrap the training records as (embedding, label) pairs.
train_dataset = RandomDataset(train_embeddings)

In [None]:

# Wrap the validation records as (embedding, label) pairs.
val_dataset = RandomDataset(val_embeddings)

In [None]:
# Sanity check: the first (embedding, label) pair produced by the dataset.
train_dataset[0]

In [None]:
# Page-locked host memory only speeds up host-to-GPU copies, so it is requested
# only when the run actually targets CUDA (it is useless on CPU-only machines).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=(DEVICE == "cuda"))
# Validation metrics are averaged over the whole set, so there is nothing to gain
# from shuffling; a fixed order keeps evaluation passes comparable.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=(DEVICE == "cuda"))

In [None]:
class RandomModel(nn.Module):
    """Feed-forward classifier head: 1024 -> 256 -> 128 -> 1 with ReLU in between.

    The last layer emits one raw logit per sample; no sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        # Stored under `model` so the state_dict keys remain `model.<index>.*`.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the resulting logits."""
        logits = self.model(x)
        return logits

In [None]:
# Build the classifier and move its parameters to the selected device;
# the trailing expression displays the module structure.
model1 = RandomModel().to(DEVICE)
model1

In [None]:
# BCEWithLogitsLoss consumes raw logits (the sigmoid is folded into the loss),
# which matches RandomModel's un-activated single output.
loss_fn = nn.BCEWithLogitsLoss()
# Adam over all of model1's parameters with the configured learning rate.
optimizer = torch.optim.Adam(model1.parameters(), lr=LEARNING_RATE)

In [None]:
def train_binary(dataloader, model, optimizer, loss_fn, device: str):
    """Train ``model`` for one pass over ``dataloader`` and report metrics.

    Args:
        dataloader: Iterable yielding ``(x, y)`` tensor batches.
        model: Module mapping ``x`` to logits with the same shape as ``y``.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(logits, y)`` returning a scalar tensor.
        device: Device string the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats, both averaged per sample.
        Accuracy is measured on the logits computed *before* each optimizer step.

    Raises:
        ValueError: If ``dataloader`` yields no samples, so no average exists.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0

    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device).float()

        logits = model(x)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metric bookkeeping must not be tracked by autograd.
        with torch.no_grad():
            preds = torch.sigmoid(logits)
            predicted_labels = (preds >= 0.5).float()
            total_correct += (predicted_labels == y).sum().item()

        batch_size = y.size(0)
        # Weight by batch size so a smaller final batch does not skew the average
        # (assumes loss_fn returns a per-batch mean -- confirm if the reduction changes).
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    if total_samples == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError("train_binary: dataloader yielded no samples")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    return avg_loss, accuracy

In [None]:
def evaluate_binary(dataloader, model, loss_fn, device: str):
    """Evaluate ``model`` over one pass of ``dataloader`` without updating it.

    Args:
        dataloader: Iterable yielding ``(x, y)`` tensor batches.
        model: Module mapping ``x`` to logits with the same shape as ``y``.
        loss_fn: Callable ``loss_fn(logits, y)`` returning a scalar tensor.
        device: Device string the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats, both averaged per sample.
        A sample counts as positive when ``sigmoid(logit) >= 0.5``.

    Raises:
        ValueError: If ``dataloader`` yields no samples, so no average exists.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device).float()

            logits = model(x)
            loss = loss_fn(logits, y)

            preds = torch.sigmoid(logits)
            predicted_labels = (preds >= 0.5).float()

            total_correct += (predicted_labels == y).sum().item()

            batch_size = y.size(0)
            # Weight by batch size so a smaller final batch does not skew the average
            # (assumes loss_fn returns a per-batch mean -- confirm if the reduction changes).
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError("evaluate_binary: dataloader yielded no samples")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    return avg_loss, accuracy

In [None]:
# Best validation loss seen so far. Start at +inf so the first evaluation always
# counts as an improvement, whatever the scale of the loss.
prev_loss = float("inf")

In [None]:
# Training metrics get one entry per epoch; validation metrics one entry per evaluation pass.
train_losses, train_accuracies = [], []
eval_losses, eval_accuracies = [], []
# Index of the first epoch the training loop will run.
start = 0

In [None]:
# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first improvement is written.
os.makedirs("./model_saves2", exist_ok=True)

start_time = time.time()
for i in range(start, EPOCHS + start):
    print(f"Epoch: {i}")

    train_loss, train_accuracy = train_binary(train_loader, model1, optimizer, loss_fn, DEVICE)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    print(f"Training: Avg_Loss: {train_loss:.8f}; Accuracy: {train_accuracy:.8f}")

    # Validation only runs on epochs divisible by EVALUATION_DELAY.
    if i % EVALUATION_DELAY == 0:
        eval_loss, eval_accuracy = evaluate_binary(val_loader, model1, loss_fn, DEVICE)
        eval_losses.append(eval_loss)
        eval_accuracies.append(eval_accuracy)

        print(f"Validation: Avg_Loss: {eval_loss:.8f}; Accuracy: {eval_accuracy:.8f}")

        # Checkpoint only when the validation loss improves on the best so far.
        if eval_loss < prev_loss:
            torch.save(model1.state_dict(), f"./model_saves2/random_model_epoch_{i}.pt")
            print(f"\033[92mSaved Model at epoch {i}\033[m")

            prev_loss = eval_loss

# Read the clock once so the total and the per-epoch average describe the same interval.
elapsed = time.time() - start_time
print(f"\033[92mTraining complete; took: {elapsed:.2f} seconds; avg: {elapsed / EPOCHS:.2f} seconds/epoch; best: {prev_loss}\033[0m")

In [None]:
fig, ax = plt.subplots(figsize=(15, 8))

# Training metrics exist for every epoch; validation metrics only for epochs
# divisible by EVALUATION_DELAY, so they need their own x positions to line up.
train_epochs = range(start, start + len(train_losses))
eval_epochs = [epoch for epoch in train_epochs if epoch % EVALUATION_DELAY == 0]
# Guard against an interrupted run leaving the two validation sequences unequal in length.
n_eval = min(len(eval_epochs), len(eval_losses))

ax.plot(train_epochs, train_losses, label="Train loss")
ax.plot(train_epochs, train_accuracies, label="Train accuracy")
ax.plot(eval_epochs[:n_eval], eval_losses[:n_eval], marker="o", label="Validation loss")

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss / accuracy")
ax.set_title("Training and validation curves")
ax.legend()
plt.show()

In [None]:
# Count training samples whose label equals 1 (fragments). Labels are
# one-element tensors, so the total comes back as a tensor as well.
fragment_counts = sum(label == 1 for _, label in train_dataset)
fragment_counts

In [None]:
# Count training samples whose label is anything other than 1 (non-fragments);
# like the fragment count, the total is a tensor.
non_fragment_counts = sum(label != 1 for _, label in train_dataset)
non_fragment_counts

In [None]:
# Ratio of fragment to non-fragment samples in the training set.
fragment_counts / non_fragment_counts

In [None]:
# Fraction of all training samples that are labelled as fragments.
fragment_counts / (non_fragment_counts + fragment_counts)

In [None]:

# Persist the weights as they stand after the last training epoch; this is not
# necessarily the checkpoint with the best validation loss.
torch.save(model1.state_dict(), "./model_saves2/final_model.pt")

In [None]:
# Load the held-out "true fragment" records and wrap them like the other splits.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open("../../experiments/true_fragment_embeddings/final_embeddings.pkl", "rb") as fragments_file:
    true_fragment_embeddings = pickle.load(fragments_file)

true_fragment_dataset = RandomDataset(true_fragment_embeddings)

In [None]:
# Inspect the first record of the true-fragment set.
true_fragment_embeddings[0]

In [None]:
# Evaluation-only loader: order does not affect the averaged metrics, so no
# shuffling, and pinned memory is requested only when copying to a CUDA device.
true_fragment_loader = DataLoader(true_fragment_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=(DEVICE == "cuda"))

In [None]:

# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine can still be loaded on a CPU-only one.
# NOTE(review): the epoch-10 file only exists if validation loss improved at epoch 10 -- confirm before relying on it.
model1.load_state_dict(torch.load("./model_saves2/random_model_epoch_10.pt", map_location=DEVICE))

In [None]:
# Score the model on the true-fragment set with the same loss/accuracy routine used for validation.
test_loss, test_accuracy = evaluate_binary(true_fragment_loader, model1, loss_fn, DEVICE)

In [None]:
# Average per-sample loss on the true-fragment set.
test_loss

In [None]:
# Accuracy on the true-fragment set.
test_accuracy

In [None]:
# Validation-set diagnostics, ordered by dependency so the section runs top to
# bottom on a fresh kernel: helper -> loader -> predictions -> confusion matrix
# -> per-class metrics -> plot. numpy and the sklearn metrics are imported in
# the notebook's import cell.


def get_predictions(dataloader, model, device: str):
    """Collect ground-truth labels and hard predictions over ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(x, y)`` tensor batches.
        model: Module mapping ``x`` to logits.
        device: Device string the inputs are moved to for the forward pass.

    Returns:
        ``(labels, predictions)`` as flat NumPy arrays. A prediction is 1.0
        when ``sigmoid(logit) >= 0.5`` and 0.0 otherwise.
    """
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for x, y in dataloader:
            logits = model(x.to(device))
            preds = torch.sigmoid(logits)
            predicted_labels = (preds >= 0.5).float()

            all_predictions.extend(predicted_labels.cpu().numpy().flatten())
            # Labels never feed the model, so they are not moved to `device`.
            all_labels.extend(y.float().cpu().numpy().flatten())

    return np.array(all_labels), np.array(all_predictions)


# Create a fresh validation loader without shuffling for consistent predictions
val_loader_eval = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=(DEVICE == "cuda"))

# Get predictions on validation set
y_true_val, y_pred_val = get_predictions(val_loader_eval, model1, DEVICE)

# Calculate confusion matrix
cm_val = confusion_matrix(y_true_val, y_pred_val)
print("Confusion Matrix (Validation Set):")
print(cm_val)

# Calculate metrics for both classes
precision = precision_score(y_true_val, y_pred_val, average=None)
recall = recall_score(y_true_val, y_pred_val, average=None)
f1 = f1_score(y_true_val, y_pred_val, average=None)

print("\nMetrics by Class:")
print(f"Class 0 (Non-Fragment) - Precision: {precision[0]:.4f}, Recall: {recall[0]:.4f}, F1: {f1[0]:.4f}")
print(f"Class 1 (Fragment)     - Precision: {precision[1]:.4f}, Recall: {recall[1]:.4f}, F1: {f1[1]:.4f}")

# Calculate macro and weighted averages
precision_macro = precision_score(y_true_val, y_pred_val, average='macro')
recall_macro = recall_score(y_true_val, y_pred_val, average='macro')
f1_macro = f1_score(y_true_val, y_pred_val, average='macro')

precision_weighted = precision_score(y_true_val, y_pred_val, average='weighted')
recall_weighted = recall_score(y_true_val, y_pred_val, average='weighted')
f1_weighted = f1_score(y_true_val, y_pred_val, average='weighted')

print("\nMacro Average:")
print(f"Precision: {precision_macro:.4f}, Recall: {recall_macro:.4f}, F1: {f1_macro:.4f}")
print("\nWeighted Average:")
print(f"Precision: {precision_weighted:.4f}, Recall: {recall_weighted:.4f}, F1: {f1_weighted:.4f}")

# Visualize confusion matrix
fig, ax = plt.subplots(figsize=(10, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm_val, display_labels=['Non-Fragment', 'Fragment'])
disp.plot(cmap='Blues', ax=ax, values_format='d')
ax.set_title('Confusion Matrix - Validation Set', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()