In [14]:
# Imports
import pandas as pd
import numpy as np
import seaborn as sns
import pickle
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from tqdm import tqdm
import pickle
import torch
import torch.nn.functional as F

In [15]:
with open("pkl files/chunks.pkl", "rb") as f:
    chunks = pickle.load(f)

In [16]:
# Importing models
from vqmodel import *

In [17]:
# Dataframe to tensor transition
images = []

for df in chunks:
    # df['ch'] is a Series of 80 columns; each item is a list of 200 values
    # Create a (200, 80) NumPy array (transpose is needed)
    matrix = np.stack(df['PWR_ch1'].to_list(), axis=1)  # shape: (200, 80)
    images.append(matrix)

# Convert the whole thing to numpy because making tensors from a list of arrays
# is extremely slow
images_array = np.array(images)


In [18]:
# Convert to torch tensor and add batch + channel dimensions
data_tensor = torch.tensor(images_array, dtype=torch.float32)  # (B, 200, 80)
data_tensor = data_tensor.unsqueeze(1)  # (B, 1, 200, 80)

In [19]:
from torch.utils.data import DataLoader

dataset = ChunkImageDataset(chunks)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
model = VQVAE(in_channels=1)

for batch in loader:
    print(batch.shape)  # (8, 1, 200, 80)
    outputs = model(batch)
    break  # for test

torch.Size([8, 1, 200, 80])


In [20]:
from sklearn.model_selection import train_test_split

# 1. First split: 80% train, 20% temp
train_chunks, temp_chunks = train_test_split(
    chunks, test_size=0.2, random_state=42
)

In [21]:
# 2. Split temp into 10% val, 10% test
val_chunks, test_chunks = train_test_split(
    temp_chunks, test_size=0.5, random_state=42
)

In [22]:

# Check counts
print(f"Train: {len(train_chunks)}")
print(f"Validation: {len(val_chunks)}")
print(f"Test: {len(test_chunks)}")

Train: 2478
Validation: 310
Test: 310


In [23]:
from torch.utils.data import DataLoader

# Wrap into datasets
train_dataset = ChunkImageDataset(train_chunks)
val_dataset = ChunkImageDataset(val_chunks)
test_dataset = ChunkImageDataset(test_chunks)

In [24]:
# Loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)
test_loader = DataLoader(test_dataset, batch_size=8)

# Model Training

In [27]:
def train_vqvae(model, train_loader, val_loader, optimizer, device="cuda", epochs=500):
    model.to(device)

    # Initialize plot variables
    train_losses, val_losses = [], []
    recon_losses, vq_losses = [], []
    code_usages, zq_stds = [], []

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, recon_total, vq_total = 0, 0, 0
        code_indices_set = set()
        zq_std_list = []

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            out["total_loss"].backward()
            optimizer.step()

            total_loss += out["total_loss"].item() * batch.size(0)
            recon_total += out["recon_loss"].item() * batch.size(0)
            vq_total += out["vq_loss"].item() * batch.size(0)
            code_indices_set.update(out["indices"].detach().cpu().numpy().tolist())
            zq_std_list.append(out["z_q"].std().item())

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                out = model(batch)
                val_loss += F.mse_loss(out["recon_x"], batch).item() * batch.size(0)

        train_losses.append(total_loss / len(train_loader.dataset))
        val_losses.append(val_loss / len(val_loader.dataset))
        recon_losses.append(recon_total / len(train_loader.dataset))
        vq_losses.append(vq_total / len(train_loader.dataset))
        code_usages.append(len(code_indices_set))
        zq_stds.append(np.mean(zq_std_list))

        print(f"[Epoch {epoch}] Train Loss: {train_losses[-1]:.4f} | "
              f"Val Loss: {val_losses[-1]:.4f} | "
              f"Codes Used: {code_usages[-1]}")
        print(f"z_q mean: {out['z_q'].mean().item():.6f} | std: {zq_stds[-1]:.6f}")

    # Plotting
    epochs_range = range(1, epochs + 1)
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    plt.plot(epochs_range, train_losses, label='Train Loss')
    plt.plot(epochs_range, val_losses, label='Val Loss')
    plt.title("Total Loss")
    plt.legend()

    plt.subplot(2, 2, 2)
    plt.plot(epochs_range, recon_losses, label='Reconstruction Loss')
    plt.title("Reconstruction Loss")
    plt.legend()

    plt.subplot(2, 2, 3)
    plt.plot(epochs_range, vq_losses, label='VQ Loss')
    plt.title("VQ Loss")
    plt.legend()

    plt.subplot(2, 2, 4)
    plt.plot(epochs_range, code_usages, label='Unique Codes Used')
    plt.title("Codebook Usage")
    plt.legend()

    plt.tight_layout()
    plt.show()


In [None]:
model = VQVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

train_vqvae(model, train_loader, val_loader, optimizer, epochs=10)

[Epoch 1] Train Loss: 1.4301 | Val Loss: 0.7682 | Codes Used: 60
z_q mean: 0.065326 | std: 0.502104
[Epoch 2] Train Loss: 1.8675 | Val Loss: 0.7133 | Codes Used: 69
z_q mean: 0.051838 | std: 0.519405
[Epoch 3] Train Loss: 1.6549 | Val Loss: 0.6937 | Codes Used: 73
z_q mean: 0.046422 | std: 0.531193
[Epoch 4] Train Loss: 1.6242 | Val Loss: 0.6738 | Codes Used: 59
z_q mean: 0.030272 | std: 0.544589
[Epoch 5] Train Loss: 1.5432 | Val Loss: 0.6541 | Codes Used: 59
z_q mean: 0.017414 | std: 0.548272


In [None]:
torch.save(model.state_dict(), "model pt files/vqvae_model.pth")