In [None]:
import torch
from model_pq import Generator , Discriminator, plot_images
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Seed PyTorch's global RNG so later random draws (e.g. the latent vectors
# sampled by generate_latent) are reproducible across runs.
SEED = 1
torch.manual_seed(SEED)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Preprocessing pipeline applied to every FashionMNIST image.
transformer = transforms.Compose(
    [
        transforms.Resize(64),                 # scale so the shorter side is 64 px
        transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1], single channel
    ]
)

# Training split, downloaded into the current directory if not already present.
train_data = datasets.FashionMNIST(
    root=".",
    train=True,
    download=True,
    transform=transformer,
)

# NOTE(review): not used by the cells visible here (only train_data.classes is);
# kept in case other cells iterate over it.
train_dataloader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Instantiate both networks (generator first, then discriminator) and move
# their parameters to the selected device.
gen, disc = Generator().to(device), Discriminator().to(device)

In [None]:
# Restore pretrained weights. map_location remaps tensors that were saved on a
# different device (e.g. a CUDA checkpoint opened on a CPU-only machine, where
# `device` fell back to "cpu") instead of failing during deserialization.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints;
# on torch >= 1.13 consider also passing weights_only=True.
gen.load_state_dict(torch.load('generator_param.pth', map_location=device))
disc.load_state_dict(torch.load('discriminator_param.pth', map_location=device))

In [None]:
# Number of evenly spaced points (both endpoints included) along the
# latent-space interpolation.
n_inter = 10


def generate_latent(latent_dim, sample_dim, device=None):
    """Draw ``sample_dim`` latent vectors from a standard normal distribution.

    Args:
        latent_dim: Length of each latent vector.
        sample_dim: Number of vectors to draw.
        device: Optional device to allocate on; ``None`` (the default) keeps
            torch's default device, as before.

    Returns:
        Tensor of shape ``(sample_dim, latent_dim)``.
    """
    return torch.randn(sample_dim, latent_dim, device=device)


def interpolate_points(p1, p2, n_inter=n_inter):
    """Linearly interpolate between two points.

    Args:
        p1: Start point; presumably a 1-D latent vector -- TODO confirm for
            other shapes (the mixing weights have shape ``(n_inter, 1)``).
        p2: End point, broadcast-compatible with ``p1``.
        n_inter: Number of evenly spaced weights in ``[0, 1]``, endpoints
            included. NOTE: the default is bound when this cell runs; pass it
            explicitly if ``n_inter`` is changed afterwards.

    Returns:
        Tensor whose row ``k`` is ``p1 * (1 - t_k) + t_k * p2``, where ``t``
        runs evenly from 0 to 1. For 1-D inputs of length ``d`` its shape is
        ``(n_inter, d)``.
    """
    # Build the weights on the inputs' device so latents that already live on
    # the GPU don't hit a CPU/CUDA device-mismatch error. Non-tensor inputs
    # (e.g. Python floats) keep the default device, as before.
    device = p1.device if torch.is_tensor(p1) else None
    ratios = torch.linspace(0, 1, steps=n_inter, device=device).reshape(-1, 1)
    return p1 * (1 - ratios) + ratios * p2


# Two random latent endpoints; 100 presumably matches the Generator's latent
# size -- TODO confirm against model_pq.
s_g = generate_latent(100, 2)

# Pass n_inter explicitly: interpolate_points' default was frozen when its cell
# ran and could disagree with the label count below if n_inter was changed.
interpolated = interpolate_points(s_g[0], s_g[1], n_inter).to(device)

# Inference only. eval() turns off training-mode layer behaviour (e.g. BatchNorm
# batch statistics and running-stat updates, Dropout) if Generator has such
# layers; no_grad() avoids building an autograd graph we never use.
gen.eval()
with torch.no_grad():
    for i in range(3):  # first three FashionMNIST classes
        # One integer class label per interpolated latent, shape (n_inter, 1).
        labels = torch.full((n_inter, 1), i, dtype=torch.long, device=device)

        prediction = gen((interpolated, labels))
        pred = prediction.cpu()

        plot_images(pred, train_data.classes[i])
    