In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from model import build_model
import tqdm
from data import load_data
import matplotlib.pyplot as plt



In [None]:
LATENT_DIM = 256
BATCH_SIZE = 8

train_loader, test_loader = load_data(train_batch_size=BATCH_SIZE, test_batch_size=2)
model = build_model(LATENT_DIM, BATCH_SIZE)

In [None]:
def evaluate(model: nn.Module, num_data):
    model.eval()
    progress_bar = tqdm.tqdm(test_loader)
    generator_losses = []
    descriminator_losses = []
    for i, (images, _) in enumerate(progress_bar):
      with torch.no_grad():
        generator_loss, descriminator_loss = model(images)
        generator_losses.append(generator_loss)
        descriminator_losses.append(descriminator_loss)
        if i == num_data:
            break
    
    eval_generator_mean_loss = torch.tensor(generator_losses).mean()
    eval_descriminator_mean_loss = torch.tensor(descriminator_losses).mean()
    
    img = model.generate_image_batch(1).squeeze().detach().cpu()
    print(plt.imshow(img))
    
    model.train()
    
    return eval_descriminator_mean_loss, eval_generator_mean_loss
 
def train(model, num_epochs, eval_epoch):
    descriminator_optimizer = torch.optim.AdamW(model.descrimator.parameters(), lr=1e-3)
    generator_optimizer = torch.optim.AdamW(model.generator.parameters(), lr=1e-3)
    
    progress = tqdm.tqdm(train_loader, dynamic_ncols=True)
    for epoch in range(num_epochs):
        model.train()
        progress.set_description(f'Epoch: {epoch}')
        generator_losses = []
        descriminator_losses = []
        for images, _ in progress:
          generator_loss, descriminator_loss = model(images)
          descriminator_optimizer.zero_grad()
          descriminator_loss.backward()
          descriminator_optimizer.step()
          
          generator_optimizer.zero_grad()
          generator_loss.backward()
          generator_optimizer.step()
          
          generator_losses.append(generator_loss.item())
          descriminator_losses.append(descriminator_loss.item())
      
        generator_mean_loss = torch.tensor(generator_losses).mean()
        descriminator_mean_loss = torch.tensor(descriminator_losses).mean()
       
        if (epoch+1) % eval_epoch == 0:
          eval_descriminator_loss, eval_generator_loss = evaluate(model, num_data=20)
          progress.set_postfix({'generator_loss': generator_mean_loss, 'descriminator_loss': descriminator_mean_loss, 'eval_gen_loss': eval_generator_loss, 'eval_des_loss':eval_descriminator_loss})
        else:
         progress.set_postfix({'generator_loss': generator_mean_loss, 'descriminator_loss': descriminator_mean_loss})
       
      
          
    

In [5]:

train(generator=generator, descriminator=descriminator, num_epochs=20, eval_epoch=2)

Epoch: 0:   0%|          | 0/7500 [00:00<?, ?it/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.