In [None]:
from google.colab import drive
import os
import glob
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
drive.mount('/content/drive')

data_directory = "/content/drive/MyDrive/data"

game_to_label = {"pinpong": 0, "carracingv3": 1, "airraid": 2}
pkl_files = glob.glob(os.path.join(data_directory, "*.pkl"))

Mounted at /content/drive


In [None]:

class LazyGameDataset(Dataset):
    def __init__(self, pkl_files, game_to_label):
        self.index = []
        self.pkl_files = pkl_files
        self.game_to_label = game_to_label
        print("Building index for lazy dataset...")
        for file in pkl_files:
            game_name = os.path.basename(file).split('.')[0]
            label = game_to_label[game_name]
            print(f"Indexing file: {file}")
            with open(file, "rb") as f:
                rollouts = pickle.load(f)
            for ep_idx, episode in enumerate(rollouts):
                for trans_idx, transition in enumerate(episode):
                    self.index.append((file, ep_idx, trans_idx, label))
        print(f"Total samples indexed: {len(self.index)}")
        self.current_file = None
        self.current_data = None

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        file, ep_idx, trans_idx, label = self.index[idx]
        if file != self.current_file:
            with open(file, "rb") as f:
                self.current_data = pickle.load(f)
            self.current_file = file
        rollouts = self.current_data
        episode = rollouts[ep_idx]
        transition = episode[trans_idx]
        obs, action, reward, next_obs, done = transition
        obs = np.array(obs)
        obs = np.transpose(obs, (2, 0, 1))
        obs = torch.tensor(obs, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        return obs, label

In [None]:
dataset = LazyGameDataset(pkl_files, game_to_label)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

selected_indices = {}
for i, (_, _, _, label) in enumerate(dataset.index):
    if label not in selected_indices:
        selected_indices[label] = i
    if len(selected_indices) == len(game_to_label):
        break
selected_indices = [selected_indices[i] for i in range(len(game_to_label))]
print("Selected sample indices for reconstruction:", selected_indices)
selected_samples = [dataset[i] for i in selected_indices]
selected_states = torch.stack([s for s, l in selected_samples])
selected_labels = torch.tensor([l for s, l in selected_samples], dtype=torch.long)

Building index for lazy dataset...
Indexing file: /content/drive/MyDrive/data/carracingv3.pkl
Indexing file: /content/drive/MyDrive/data/airraid.pkl
Indexing file: /content/drive/MyDrive/data/pinpong.pkl
Total samples indexed: 141513
Selected sample indices for reconstruction: [96413, 0, 48330]


In [None]:
latent_dim = 128
num_classes = 3

In [None]:
class CVAE(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(CVAE, self).__init__()
        self.num_classes = num_classes
        self.encoder = nn.Sequential(
            nn.Conv2d(3 + num_classes, 32, kernel_size=4, stride=2, padding=1),  # -> [32, 32, 32]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),               # -> [64, 16, 16]
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),              # -> [128, 8, 8]
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),             # -> [256, 4, 4]
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_decode = nn.Linear(latent_dim + num_classes, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),    # -> [128, 8, 8]
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),     # -> [64, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),      # -> [32, 32, 32]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),       # -> [3, 64, 64]
            nn.Sigmoid()
        )

    def encode(self, x, label):
        one_hot = torch.nn.functional.one_hot(label, num_classes=self.num_classes).float()
        one_hot = one_hot.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat([x, one_hot], dim=1)
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z, label):
        one_hot = torch.nn.functional.one_hot(label, num_classes=self.num_classes).float()
        z = torch.cat([z, one_hot], dim=1)
        x = self.fc_decode(z)
        x = x.view(-1, 256, 4, 4)
        x = self.decoder(x)
        return x

    def forward(self, x, label):
        mu, logvar = self.encode(x, label)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z, label)
        return x_recon, mu, logvar

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CVAE(latent_dim, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [1]:
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total_loss = BCE + KLD
    return total_loss, BCE, KLD


epochs = 80
loss_list = []
bce_list = []
kld_list = []
model.train()

for epoch in range(1, epochs + 1):
    train_loss = 0
    total_bce = 0
    total_kld = 0
    for states_batch, labels_batch in dataloader:
        states_batch = states_batch.to(device)
        labels_batch = labels_batch.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(states_batch, labels_batch)
        loss, bce_loss, kld_loss = loss_function(recon_batch, states_batch, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        total_bce += bce_loss.item()
        total_kld += kld_loss.item()
        optimizer.step()

    avg_loss = train_loss / len(dataset)
    avg_bce = total_bce / len(dataset)
    avg_kld = total_kld / len(dataset)
    loss_list.append(avg_loss)
    bce_list.append(avg_bce)
    kld_list.append(avg_kld)
    print(f"Epoch {epoch}, Avg Total Loss: {avg_loss:.4f}, BCE: {avg_bce:.4f}, KLD: {avg_kld:.4f}")


    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            states_tensor = selected_states.to(device)
            labels_tensor = selected_labels.to(device)
            recon, _, _ = model(states_tensor, labels_tensor)
            for i in range(len(selected_labels)):

                original = selected_states[i].permute(1, 2, 0).cpu().numpy()
                reconstructed = recon[i].permute(1, 2, 0).cpu().numpy()
                plt.figure(figsize=(10, 5))
                plt.subplot(1, 2, 1)
                plt.imshow(original.astype(np.float32))
                plt.title(f"Original - Game {selected_labels[i].item()}")
                plt.axis('off')
                plt.subplot(1, 2, 2)
                plt.imshow(reconstructed.astype(np.float32))
                plt.title(f"Reconstructed - Epoch {epoch}")
                plt.axis('off')
                plt.savefig(f"/content/drive/MyDrive/recon_epoch_{epoch}_game_{selected_labels[i].item()}.png")
                plt.close()
        model.train()


    if epoch % 10 == 0:
        save_path = f"/content/drive/MyDrive/general_CVAE_epoch_{epoch}.pth"
        torch.save(model.state_dict(), save_path)
        print(f"Model saved at epoch {epoch}!")


torch.save(model.state_dict(), "/content/drive/MyDrive/general_CVAE.pth")
print("Model saved at the end of training!")

NameError: name 'model' is not defined

In [None]:
plt.figure(figsize=(12, 6))
plt.plot(range(1, epochs + 1), loss_list, label='Total Loss')
plt.plot(range(1, epochs + 1), bce_list, label='BCE Loss')
plt.plot(range(1, epochs + 1), kld_list, label='KLD Loss')
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.title("Training Losses over Epochs")
plt.legend()
plt.savefig("/content/drive/MyDrive/loss_plot.png")
plt.show()