In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm



from utils import experiment_classif_simple, format_results, \
    plot_loss_acc_over_epochs, plot_time_vs_parameters


# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
# `torch.backends.mps` is missing on older torch builds, so look it up
# defensively instead of assuming the attribute exists.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Seed torch's RNGs so weight init, DataLoader shuffling and sampling below are
# repeatable across runs (some GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

## Basic operations

In [None]:
# Baseline: one integer addition, assigned to the throwaway name `_`.
_ = 1 + 1

In [None]:
# The same addition repeated 10 million times in a pure-Python loop. tqdm
# reports an iterations/second rate, which makes the per-iteration cost of
# the interpreter visible (compare with the single addition above).
for _ in tqdm(range(10000000)):
    _ = 1 + 1

## Set Up the MNIST Experiment

### Dataset

In [None]:
# ToTensor converts each image to a float tensor. No normalization is applied,
# so pixel values stay in [0, 1], which the BCE loss in vae_loss requires.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same download location and preprocessing.
mnist_kwargs = dict(root="./data", download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Small shuffled batches for training; large, fixed-order batches for evaluation.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Preview the first four test digits with their ground-truth labels.
# The test loader is not shuffled, so the same digits appear on every run.
example_data, example_targets = next(iter(test_loader))
fig, axes = plt.subplots(1, 4)
for ax, image, label in zip(axes, example_data[:4], example_targets[:4]):
    ax.imshow(image.squeeze(), cmap="gray")
    ax.set_title(f"True: {label.item()}")
    ax.axis("off")
fig.tight_layout()
plt.show()

## Classification tasks

### Experiment: Simple MLP

In [None]:
class SimpleMLP(torch.nn.Module):
    """Fully connected classifier: flatten -> [Linear -> ReLU]* -> Linear.

    Args:
        hidden_dims: widths of the hidden layers, in order. An empty sequence
            gives a single linear layer.
        input_dim: number of features per sample after flattening
            (28*28 for MNIST).
        output_dim: number of output logits.
    """

    # Immutable default: a shared mutable `[]` default is a classic Python
    # pitfall, even when it is only read.
    def __init__(self, hidden_dims=(), input_dim=28*28, output_dim=10):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(torch.nn.Linear(prev_dim, h_dim))
            layers.append(torch.nn.ReLU())
            prev_dim = h_dim
        layers.append(torch.nn.Linear(prev_dim, output_dim))
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Flatten everything after the batch dimension and return (batch, output_dim) logits."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)
        return self.network(x)

In [None]:
# Train one SimpleMLP per architecture; keep the metrics and the trained model
# of every run.
hidden_dims_tested = [
    [32, 16],
    [64, 32],
    [128, 64, 32]
]
all_outputs_mlp = []
all_models_mlp = []

for hidden_dims in hidden_dims_tested:
    print(f"Training SimpleMLP with hidden dimensions: {hidden_dims}")
    model = SimpleMLP(hidden_dims=hidden_dims)
    # Hand this exact instance to the experiment, so the model stored in
    # all_models_mlp is the one that was trained. This assumes
    # experiment_classif_simple trains in place, as the CNN loop relies on too.
    output = experiment_classif_simple(
        model,
        train_loader, test_loader,
        nbr_epochs=10, device=device,
        run_name=f"SimpleMLP_{'_'.join(map(str, hidden_dims))}"
    )
    all_outputs_mlp.append(output)
    all_models_mlp.append(model)

results_long_mlp, results_summary_mlp = format_results(all_outputs_mlp)

### Experiment: Simple CNN

In [None]:
class SimpleCNN(torch.nn.Module):
    """Small convolutional classifier.

    Architecture: [Conv3x3(padding=1) -> ReLU]* -> MaxPool(2) -> flatten ->
    Linear(128) -> ReLU -> Dropout(0.25) -> Linear(output_dim).

    Args:
        input_channels: channels of the input images (1 for MNIST).
        hidden_channels: output channels of each conv layer, in order; must be
            non-empty.
        output_dim: number of output logits.
        input_size: (height, width) of the input images. The 3x3/padding-1
            convolutions keep the spatial size and the pooling halves it
            (rounding down), which determines the flattened feature size.

    Raises:
        ValueError: if ``hidden_channels`` is empty.
    """

    def __init__(self, input_channels=1, hidden_channels=(32, 64), output_dim=10,
                 input_size=(28, 28)):
        super().__init__()
        if len(hidden_channels) == 0:
            raise ValueError("hidden_channels must contain at least one layer")
        conv_part = []
        in_channels = input_channels
        for out_channels in hidden_channels:
            conv_part.append(torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            conv_part.append(torch.nn.ReLU())
            in_channels = out_channels
        self.last_channels = hidden_channels[-1]
        self.conv_part = torch.nn.Sequential(*conv_part)
        self.pool = torch.nn.MaxPool2d(2, 2)
        height, width = input_size
        # Flattened size after pooling (64 * 14 * 14 for the MNIST defaults).
        self.flat_dim = self.last_channels * (height // 2) * (width // 2)
        self.mlp_part = torch.nn.Sequential(
            torch.nn.Linear(self.flat_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(128, output_dim),
        )

    def forward(self, x):
        """Return (batch, output_dim) logits for images x of shape (batch, C, H, W)."""
        x = self.conv_part(x)  # (batch, last_channels, H, W)
        x = self.pool(x)       # (batch, last_channels, H // 2, W // 2)
        # Flatten per sample. view(-1, ...) would silently fold a wrong input
        # size into the batch dimension instead of raising.
        x = torch.flatten(x, start_dim=1)
        return self.mlp_part(x)

In [None]:
# Train one SimpleCNN per channel configuration; collect the metrics and the
# trained model of every run.
hidden_channels_tested = [
    [8],
    [16, 16]
]
all_outputs_cnn = []
all_models_cnn = []

for hidden_chans in hidden_channels_tested:
    run_suffix = "_".join(str(c) for c in hidden_chans)
    print(f"Training SimpleCNN with hidden channels: {hidden_chans}")
    model = SimpleCNN(hidden_channels=hidden_chans)
    output = experiment_classif_simple(
        model, train_loader, test_loader,
        nbr_epochs=10, device=device, run_name=f"SimpleCNN_{run_suffix}",
    )
    all_outputs_cnn.append(output)
    all_models_cnn.append(model)

results_long_cnn, results_summary_cnn = format_results(all_outputs_cnn)

### Summary

In [None]:
# Stack the MLP and CNN result tables into single frames with a fresh index.
results_summary, results_long = (
    pd.concat(frames, ignore_index=True)
    for frames in (
        [results_summary_mlp, results_summary_cnn],
        [results_long_mlp, results_long_cnn],
    )
)

In [None]:
# `plot_time_vs_parameters` comes from the local utils module. Judging by its
# name, it plots run time against parameter count for every run in the combined
# MLP + CNN summary table (TODO confirm against utils).
plot_time_vs_parameters(results_summary)

In [None]:
# `plot_loss_acc_over_epochs` (local utils module) presumably draws per-epoch
# loss/accuracy curves from the long-format results of all runs (TODO confirm).
plot_loss_acc_over_epochs(results_long)

## Image generation tasks

### Conditional VAE (CVAE)

In [None]:
class CVAE_MLP(nn.Module):
    """Conditional VAE with a fully connected encoder and decoder.

    The encoder sees the flattened image concatenated with a one-hot label.
    The decoder sees the latent code concatenated with the same one-hot label
    and outputs per-pixel values in (0, 1) through a sigmoid.

    Args:
        input_shape: shape of one image, e.g. (1, 28, 28).
        hidden_dims: encoder hidden widths; the decoder uses them in reverse
            order. May be empty (linear encoder/decoder).
        latent_dim: size of the latent code z.
        num_classes: number of label classes (width of the one-hot vectors).
    """

    def __init__(self, input_shape=(1, 28, 28), hidden_dims=(400,),
                 latent_dim=20, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.input_dim = torch.prod(torch.tensor(input_shape)).item()

        # Encoder: flattened image (input_dim) + one-hot label (num_classes).
        encoder = []
        in_dim = self.input_dim + num_classes
        for h in hidden_dims:
            encoder.append(nn.Linear(in_dim, h))
            encoder.append(nn.ReLU())
            in_dim = h
        self.encoder = nn.Sequential(*encoder)
        self.fc_mu = nn.Linear(in_dim, latent_dim)
        self.fc_logvar = nn.Linear(in_dim, latent_dim)

        # Decoder: latent z + one-hot label, hidden widths mirrored.
        decoder = []
        in_dim = latent_dim + num_classes
        for h in reversed(hidden_dims):
            decoder.append(nn.Linear(in_dim, h))
            decoder.append(nn.ReLU())
            in_dim = h
        self.decoder = nn.Sequential(*decoder)
        self.fc_out = nn.Linear(in_dim, self.input_dim)

    # The encoder/decoder Sequentials already end in ReLU when hidden_dims is
    # non-empty. An extra F.relu on their output would be a no-op in that case,
    # but with an empty hidden_dims it would clip raw inputs and the latent z to
    # non-negative values, so it is not applied.

    def encode(self, x, y_onehot):
        """Return (mu, logvar), each (batch, latent_dim), for images x and one-hot labels."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.input_dim)
        h = self.encoder(torch.cat([x, y_onehot], dim=1))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y_onehot):
        """Map latent codes and one-hot labels to images of shape (batch, *input_shape)."""
        h = self.decoder(torch.cat([z, y_onehot], dim=1))
        x_hat = torch.sigmoid(self.fc_out(h))
        return x_hat.view(-1, *self.input_shape)

    def forward(self, x, y_onehot):
        """Encode, sample z, decode; return (x_hat, mu, logvar)."""
        mu, logvar = self.encode(x, y_onehot)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z, y_onehot)
        return x_hat, mu, logvar


In [None]:
class CVAE_CNN(nn.Module):
    """Conditional VAE with a convolutional encoder and a transposed-conv decoder.

    Conditioning:
    - The encoder receives the label as ``num_classes`` extra input channels,
      with the one-hot vector repeated at every pixel.
    - The decoder receives the one-hot label concatenated to the latent code.

    Spatial layout: 3x3/padding-1 convolutions keep the input size, one 2x2
    max-pool halves it, and the final transposed convolution doubles it back.
    Height and width must therefore be even (28x28 by default).

    Args:
        input_channels: image channels (1 for MNIST).
        hidden_channels: encoder conv widths, in order; the decoder mirrors
            them. Must be non-empty.
        latent_dim: size of the latent code z.
        num_classes: number of label classes.
        input_size: (height, width) of the images.

    Raises:
        ValueError: if ``hidden_channels`` is empty or ``input_size`` is odd.
    """

    def __init__(
        self,
        input_channels=1,
        hidden_channels=(32, 64),
        latent_dim=16,
        num_classes=10,
        input_size=(28, 28),
    ):
        super().__init__()
        if len(hidden_channels) == 0:
            raise ValueError("hidden_channels must contain at least one layer")
        height, width = input_size
        if height % 2 or width % 2:
            raise ValueError("input_size must have even height and width")

        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.hidden_channels = list(hidden_channels)
        # Spatial size after the single 2x2 pooling step (14x14 for MNIST).
        self.pooled_size = (height // 2, width // 2)

        # -------------------------------------------------
        #  ENCODER: conv stack on [image, label map] + pooling
        # -------------------------------------------------
        in_channels = input_channels + num_classes
        conv_blocks = []
        for out_channels in self.hidden_channels:
            conv_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            conv_blocks.append(nn.ReLU())
            in_channels = out_channels
        self.encoder_conv = nn.Sequential(*conv_blocks)
        self.pool = nn.MaxPool2d(2, 2)

        flat_dim = self.hidden_channels[-1] * self.pooled_size[0] * self.pooled_size[1]
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        # -------------------------------------------------
        #  DECODER: Linear -> transposed-conv stack
        # -------------------------------------------------
        self.fc_decode = nn.Linear(latent_dim + num_classes, flat_dim)

        decoder_channels = self.hidden_channels[::-1]
        targets = decoder_channels[1:] + [input_channels]
        deconv_blocks = []
        in_channels = decoder_channels[0]
        for i, out_channels in enumerate(targets):
            # Decide on position, not on channel counts. This also covers a
            # single hidden layer, where the first layer is the last one.
            is_last = i == len(targets) - 1
            if is_last:
                # Final layer: (kernel=4, stride=2, padding=1) exactly doubles
                # H and W. No activation: decode() applies the sigmoid.
                deconv_blocks.append(
                    nn.ConvTranspose2d(in_channels, out_channels,
                                       kernel_size=4, stride=2, padding=1)
                )
            else:
                # Intermediate layers keep the pooled resolution.
                deconv_blocks.append(
                    nn.ConvTranspose2d(in_channels, out_channels,
                                       kernel_size=3, stride=1, padding=1)
                )
                deconv_blocks.append(nn.ReLU())
            in_channels = out_channels
        self.decoder_conv = nn.Sequential(*deconv_blocks)

    def make_label_map(self, y, H, W):
        """Convert integer labels y (B,) into a map (B, num_classes, H, W).

        Every pixel carries the same one-hot vector.
        """
        y_onehot = F.one_hot(y, num_classes=self.num_classes).float()
        return y_onehot[:, :, None, None].expand(-1, -1, H, W)

    def encode(self, x, y):
        """Return (mu, logvar) for images x (B, C, H, W) and integer labels y (B,)."""
        _, _, height, width = x.shape
        y_map = self.make_label_map(y, height, width)
        h = self.encoder_conv(torch.cat([x, y_map], dim=1))
        h = self.pool(h)  # (B, C_last, H // 2, W // 2)
        h = torch.flatten(h, start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Map latent codes z (B, latent_dim) and integer labels y (B,) to images in (0, 1)."""
        y_onehot = F.one_hot(y, num_classes=self.num_classes).float()
        h = self.fc_decode(torch.cat([z, y_onehot], dim=1))
        h = h.view(z.size(0), self.hidden_channels[-1], *self.pooled_size)
        return torch.sigmoid(self.decoder_conv(h))

    def forward(self, x, y):
        """Encode, sample z, decode; return (x_hat, mu, logvar)."""
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z, y)
        return x_hat, mu, logvar


In [None]:
def vae_loss(x_hat, x, mu, logvar):
    """Per-sample negative ELBO for a VAE with a Bernoulli (BCE) decoder.

    Args:
        x_hat: reconstructions in [0, 1], same shape as ``x``.
        x: targets in [0, 1]; dimension 0 is the batch.
        mu: posterior means.
        logvar: posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: (summed BCE + summed KL to N(0, I)) / batch size.
    """
    batch_size = x.size(0)
    reconstruction_term = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over every element.
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total = reconstruction_term + kl_term
    return total / batch_size

In [None]:
def train_cvae(model, dataloader, epochs=10, lr=1e-3, device="cpu"):
    """Train a CVAE_CNN or CVAE_MLP with Adam on the negative ELBO (vae_loss).

    After each epoch, prints the mean loss per sample over the whole dataset.

    Args:
        model: a ``CVAE_CNN`` (called as ``model(x, y)`` with integer labels)
            or a ``CVAE_MLP`` (called as ``model(x, y_onehot)``).
        dataloader: yields ``(x, y)`` batches of images and integer labels.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: device the model and each batch are moved to.

    Raises:
        TypeError: if ``model`` is neither a CVAE_CNN nor a CVAE_MLP.
    """
    if not isinstance(model, (CVAE_CNN, CVAE_MLP)):
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    model.to(device)
    # The model may have been left in eval mode (e.g. after sampling).
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        total_loss = 0.0

        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            if isinstance(model, CVAE_CNN):
                x_hat, mu, logvar = model(x, y)
            else:
                # y is already on `device`, so the one-hot tensor is too.
                y_onehot = F.one_hot(y, num_classes=model.num_classes).float()
                x_hat, mu, logvar = model(x, y_onehot)
            loss = vae_loss(x_hat, x, mu, logvar)
            loss.backward()
            optimizer.step()

            # vae_loss is already averaged over the batch. Weight it by the
            # batch size so the epoch figure below is a true per-sample mean,
            # even when the last batch is smaller.
            total_loss += loss.item() * x.size(0)

        print(f"Epoch {epoch+1} | loss = {total_loss / len(dataloader.dataset):.4f}")


In [None]:
def generate_digit(model, n_samples=8, device="cpu"):
    """Sample ``n_samples`` images per class from a CVAE prior and plot them.

    For each class label, latent codes are drawn from N(0, I) and decoded
    conditioned on that label. The images are shown in a grid with one row
    per class.

    Args:
        model: a ``CVAE_CNN`` or ``CVAE_MLP``, already on ``device``.
        n_samples: images generated per class (grid columns).
        device: device on which labels and latents are placed.

    Returns:
        CPU tensor of shape (num_classes * n_samples, *image_shape) with all
        generated images, grouped by class (class 0 first).

    Raises:
        TypeError: if ``model`` is neither a CVAE_CNN nor a CVAE_MLP.
    """
    if not isinstance(model, (CVAE_CNN, CVAE_MLP)):
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    num_classes = model.num_classes
    was_training = model.training
    model.eval()
    all_imgs = []
    with torch.no_grad():
        for digit in range(num_classes):
            y = torch.full((n_samples,), digit, dtype=torch.long, device=device)
            # Draw on CPU, then move, so a fixed seed gives the same latents
            # on every device.
            z = torch.randn(n_samples, model.latent_dim).to(device)
            if isinstance(model, CVAE_CNN):
                imgs = model.decode(z, y)
            else:
                y_onehot = F.one_hot(y, num_classes=num_classes).float()
                imgs = model.decode(z, y_onehot)
            all_imgs.append(imgs.cpu())
    # Restore whatever mode the caller had the model in.
    model.train(was_training)

    scale = 0.7
    fig, axes = plt.subplots(
        num_classes, n_samples,
        figsize=(n_samples * scale, num_classes * scale),
        squeeze=False,
    )
    for digit, imgs in enumerate(all_imgs):
        for i in range(n_samples):
            axes[digit, i].imshow(imgs[i].squeeze(), cmap="gray")
            axes[digit, i].axis("off")
    plt.show()

    # Return every generated image, not just the last class's batch.
    return torch.cat(all_imgs)


In [None]:
# Fully connected CVAE: encoder widths 400 -> 100 (mirrored in the decoder),
# 20-dimensional latent space, trained for 10 epochs.
cvae_mlp = CVAE_MLP(latent_dim=20, hidden_dims=[400, 100])
train_cvae(model=cvae_mlp, dataloader=train_loader, epochs=10, device=device)

In [None]:
# Draw and plot 8 samples per digit from the trained MLP CVAE.
all_imgs_mlp = generate_digit(model=cvae_mlp, n_samples=8, device=device)

In [None]:
# Convolutional CVAE with 32/64/128-channel conv layers and a 20-d latent
# space, trained for 5 epochs.
cvae_cnn = CVAE_CNN(latent_dim=20, hidden_channels=[32, 64, 128])
train_cvae(model=cvae_cnn, dataloader=train_loader, epochs=5, device=device)

In [None]:
# Draw and plot 8 samples per digit from the trained convolutional CVAE.
all_imgs_cnn = generate_digit(model=cvae_cnn, n_samples=8, device=device)

## Text generation tasks

In [None]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load the pretrained BERT tokenizer and masked-language model. The model gets
# its own name so it does not overwrite `model` from the classification runs.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()

# Sentences containing a [MASK] token for BERT to fill in.
text_examples = [
    "The capital of France is [MASK].",
    "The largest planet in our solar system is [MASK].",
    "The most passionate programming language is [MASK]."
]
top_k = 5

for text in text_examples:
    inputs = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = bert_model(**inputs).logits  # (1, seq_len, vocab_size)

    # Positions of every [MASK] token in the (single) input sequence.
    mask_positions = (inputs.input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    # Softmax over the vocabulary turns logits into actual probabilities.
    mask_probs = logits[0, mask_positions].softmax(dim=-1)  # (n_masks, vocab_size)
    top_probs, top_ids = mask_probs.topk(top_k, dim=-1)

    print("--------------------------------")
    print(f"Original text: {text}")
    filled_text = text
    for probs_row, ids_row in zip(top_probs, top_ids):
        str_top_k = [
            f"{tokenizer.decode([token_id])} ({prob.item():.3f})"
            for token_id, prob in zip(ids_row, probs_row)
        ]
        print(f"Predicted token (top {top_k}): {str_top_k}")
        # Fill masks left to right, each with its own most likely token.
        filled_text = filled_text.replace(tokenizer.mask_token, tokenizer.decode([ids_row[0]]), 1)
    print(f"Filled sentence: {filled_text}")


In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the pretrained GPT-2 tokenizer and causal language model.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
gpt2_model.eval()

# Prompts for open-ended continuation.
prompt_examples = [
    "Once upon a time in a galaxy far, far away",
    "In the future, artificial intelligence will",
    "The secret to a happy life is"
]

for prompt in prompt_examples:
    inputs = gpt2_tokenizer(prompt, return_tensors="pt").to(device)

    # Sample a continuation with top-k / nucleus (top-p) filtering.
    # - `**inputs` passes the attention mask along with the input ids.
    # - GPT-2 has no pad token, so EOS is reused explicitly; generate() would
    #   otherwise warn and fall back to EOS on its own.
    with torch.no_grad():
        output_ids = gpt2_model.generate(
            **inputs,
            max_length=100,  # total length in tokens, prompt included
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=gpt2_tokenizer.eos_token_id,
        )

    generated_text = gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print("--------------------------------")
    print("Prompt:", prompt)
    print("Generated:", generated_text)
