In [40]:
import torch
from data import load_traindata
# Pick the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader
from torch.utils.data import random_split
from models import Critic, Generator
from scipy.spatial.distance import directed_hausdorff
from augmentation import augment

mps


WGAN-GP (Wasserstein GAN with gradient penalty)

In [41]:
# Single seed shared by every RNG this notebook touches.
SEED = 42
# Seed NumPy as well: it is imported and used here, and helper modules may draw
# from its global RNG (not visible from this notebook -- verify if exact
# reproducibility matters).
np.random.seed(SEED)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x1144c84f0>

In [42]:
# --- Data layout ---
num_subclasses = 100   # passed to load_traindata and used to size the reshape
seq_size = 250         # samples per window; each 1000-sample record is cut into 1000 / seq_size windows
split_size = 0.8       # fraction of windows assigned to the training split
num_aug = 5            # not referenced by the cells visible in this notebook

# --- Optimisation ---
epochs = 10_000
batch_size = 32
lr = 5e-5              # shared by the generator and critic optimizers

In [43]:
# Load the raw recordings and cut every 1000-sample, 12-channel record into
# non-overlapping windows of `seq_size` samples.
X, _ = load_traindata(num_subclasses)
X = torch.tensor(X, dtype=torch.float32)
windows_per_record = 1000 // seq_size
X = X.reshape(int(num_subclasses) * windows_per_record, seq_size, 12)

# Random train / held-out split over windows.
train_size = int(split_size * len(X))
test_size = len(X) - train_size
X, X_test = random_split(X, [train_size, test_size])

# NOTE(review): `augment` receives the training Subset and is expected to return
# a tensor indexable as [window, sample, channel] -- confirm against augmentation.py.
X = augment(X)
X_input = X[:, :, 0]   # First channel
Y_target = X[:, :, 1]  # Second channel

# Materialise the held-out windows by indexing the Subset itself. Indexing
# `X_test.dataset` with 0..len(X_test)-1 (as before) returned the first rows of
# the *full* tensor, which overlap the training split.
X_test_tensors = [X_test[idx] for idx in range(len(X_test))]
X_test_tensor = torch.stack(X_test_tensors)
X_t = X_test_tensor[:, :, 0]
Y_t = X_test_tensor[:, :, 1]

(9514, 28)


In [44]:
# Pair each input-channel window with its target-channel window.
dataset = TensorDataset(X_input, Y_target)
test_dataset = TensorDataset(X_t, Y_t)

# Both loaders use the same batching and shuffling settings.
loader_kwargs = dict(batch_size=batch_size, shuffle=True)
dataloader = DataLoader(dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

# Networks (generator first, then critic) and one AdamW optimizer per network.
generator = Generator().to(device)
critic = Critic(seq_size).to(device)
optimizer_g = torch.optim.AdamW(generator.parameters(), lr=lr)
optimizer_c = torch.optim.AdamW(critic.parameters(), lr=lr)


def _param_count_millions(model):
    """Return the total number of parameters of `model`, in millions."""
    total = 0
    for p in model.parameters():
        total += p.numel()
    return total / 1e6


print(_param_count_millions(generator), 'M parameters for Generator')
print(_param_count_millions(critic), 'M parameters for Discriminator')

1.398017 M parameters for Generator
0.861953 M parameters for Discriminator


In [45]:
# Running frame index: passed to plotWave and replaced by its return value
# (c + 1) each time a comparison plot is drawn in the training loop.
counter = 0

In [46]:
def plotWave(X, Y, c=0):
    """Plot a generated signal against its real counterpart and print error metrics.

    Args:
        X: tensor holding the generated signal; a leading singleton dimension,
            if present, is squeezed away.
        Y: tensor holding the real signal, treated the same way as ``X``.
        c: frame index used in the plot title and the printed report.

    Returns:
        ``c + 1``, so the caller can carry a running frame counter.

    Prints the mean squared error and the symmetric Hausdorff distance between
    the two signals. The distance is computed on amplitude values only (each
    sample becomes a 1-D point), so it ignores sample order/time; it is not a
    Fréchet distance and is labelled accordingly.
    """
    x_np = X.squeeze(0).detach().cpu().numpy()
    y_np = Y.squeeze(0).detach().cpu().numpy()

    # Plotting
    plt.figure(figsize=(20, 6))
    plt.plot(x_np, color='blue', label='X (Generated Signal)')
    plt.plot(y_np, color='red', label='Y (Real Signal)')
    plt.legend()
    plt.title(f'Wave Comparison - Frame {c}')
    plt.show()

    # Mean squared error between the two waveforms.
    mse = np.mean((x_np - y_np) ** 2)

    # Symmetric Hausdorff distance: the larger of the two directed distances.
    x_pts = x_np.reshape(-1, 1)
    y_pts = y_np.reshape(-1, 1)
    hausdorff = max(directed_hausdorff(x_pts, y_pts)[0],
                    directed_hausdorff(y_pts, x_pts)[0])

    print(f"Frame {c}:")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Hausdorff Distance: {hausdorff:.4f}")

    return c + 1


In [47]:
def plot_losses(d_losses, g_losses):
    """Plot exponentially smoothed critic ('D Loss') and generator ('G Loss') curves.

    Args:
        d_losses: sequence of critic loss values.
        g_losses: sequence of generator loss values.
    """
    def _ema(values, factor=0.9):
        # Exponential moving average seeded with the first value.
        out = []
        for value in values:
            if not out:
                out.append(value)
                continue
            out.append(out[-1] * factor + value * (1 - factor))
        return out

    for series, label in ((d_losses, 'D Loss'), (g_losses, 'G Loss')):
        plt.plot(_ema(series), label=label)
    plt.legend()
    plt.show()

In [48]:
# Per-epoch loss histories appended to by the training loop.
g_losses = []
d_losses = []

In [56]:
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    Args:
        critic: callable mapping a batch of samples to critic scores.
        real_samples: batch of real samples, batch dimension first.
        fake_samples: batch of generated samples, same shape as ``real_samples``.
        device: device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``, where
        ``grad`` is the gradient of the critic output with respect to each
        interpolated sample, flattened per sample.
    """
    batch = real_samples.size(0)
    # One coefficient per sample, broadcastable over every non-batch dimension.
    # A fixed (B, 1, 1) shape only works for 3-D inputs and silently broadcasts
    # to the wrong shape for any other rank.
    alpha_shape = (batch,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)

    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    critic_interpolates = critic(interpolates)

    gradients = torch.autograd.grad(
        outputs=critic_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_interpolates),
        create_graph=True,   # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    # reshape (not view) so non-contiguous gradients are handled as well.
    gradients = gradients.reshape(batch, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [57]:
# WGAN-GP training loop: several critic updates per batch, then one generator update.
N_CRITIC = 3        # critic updates per generator update
GP_WEIGHT = 10      # weight of the gradient penalty in the critic loss
NOISE_STD = 0.01    # std of the Gaussian noise added to the real targets
PLOT_EVERY = 10     # epochs between diagnostic plots

for epoch in range(epochs):
    for real_1, real_2 in dataloader:
        real_1, real_2 = real_1.to(device).float(), real_2.to(device).float()
        # Out-of-place so the collated batch is never modified through an alias.
        real_2 = real_2 + NOISE_STD * torch.randn_like(real_2)  # ADDING SOME NOISE
        real_2_c = real_2.unsqueeze(1)  # add the channel dimension the critic is fed

        # ---- Critic updates ----
        for _ in range(N_CRITIC):
            optimizer_c.zero_grad()
            # The critic phase never back-propagates into the generator, so skip
            # building its autograd graph instead of building and detaching it.
            with torch.no_grad():
                fake_2 = generator(real_1)
            real_scores = critic(real_2_c)
            fake_scores = critic(fake_2)

            gradient_penalty = compute_gradient_penalty(critic, real_2_c, fake_2, device)
            c_loss = torch.mean(fake_scores) - torch.mean(real_scores) + GP_WEIGHT * gradient_penalty
            c_loss.backward()
            optimizer_c.step()

        # ---- Generator update ----
        optimizer_g.zero_grad()
        fake_2 = generator(real_1)
        g_loss = -torch.mean(critic(fake_2))
        g_loss.backward()
        optimizer_g.step()

    # Record the losses of the last batch of the epoch.
    g_losses.append(g_loss.item())
    d_losses.append(c_loss.item())

    if (epoch + 1) % PLOT_EVERY == 0:
        # Print the penalty as a plain number rather than a tensor repr.
        print(epoch, gradient_penalty.item())
        plot_losses(d_losses, g_losses)
        with torch.no_grad():
            # Preview two samples from one batch of the training loader.
            real_1, real_2 = next(iter(dataloader))
            real_1, real_2 = real_1.to(torch.float32), real_2.to(torch.float32)
            real_1, real_2 = real_1.to(device), real_2.to(device)
            fake_2 = generator(real_1)
            counter = plotWave(fake_2[0], real_2[0], counter)
            # Second preview at index 12, clamped so a short batch cannot raise IndexError.
            second = min(12, fake_2.size(0) - 1)
            counter = plotWave(fake_2[second], real_2[second], counter)
    print(f"Epoch {epoch+1}/{epochs}, Critic Loss: {c_loss.item()}, Generator Loss: {g_loss.item()}")


Epoch 1/10000, Critic Loss: -0.21517981588840485, Generator Loss: -0.6548455953598022
Epoch 2/10000, Critic Loss: -0.35535651445388794, Generator Loss: -0.43051671981811523
Epoch 3/10000, Critic Loss: -0.5866774320602417, Generator Loss: 1.7068864107131958


KeyboardInterrupt: 