In [14]:
#@title imports
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import functools
import numpy as np
import time
import torch

## Get the data

In [3]:
#@title func: get_dataloader
def get_dataloader(dataset, batch_size: int = 16, shuffle: bool = True) -> DataLoader:
  return DataLoader(
      dataset=dataset,
      batch_size=batch_size,
      shuffle=shuffle
  )

In [4]:
#@title Load the datasets and create dataloaders
batch_size = 32  #@param
shuffle = True  #@param
download = True  #@param

image_transforms = transforms.Compose([
  transforms.Pad(padding=2, fill=0, padding_mode="constant"),
  transforms.ToTensor(),
  transforms.Normalize((0.5), (.5)),
])

dataset_train = MNIST(root="data/MNIST/train", train=True, download=download, transform=image_transforms)
dataset_test = MNIST(root="data/MNIST/test", train=False, download=download, transform=image_transforms)

dataloader_train = get_dataloader(dataset_train, batch_size=batch_size, shuffle=shuffle)
dataloader_test = get_dataloader(dataset_test, batch_size=batch_size, shuffle=shuffle, )

print(f"Train samples: {len(dataset_train)} ; {len(dataloader_train)} batches")
print(f"Test samples: {len(dataset_test)} ; {len(dataloader_test)} batches")

Train samples: 60000 ; 1875 batches
Test samples: 10000 ; 313 batches


In [5]:
print(dataset_train[0][0].mean(), dataset_train[0][0].std())
print(dataset_train[1][0].min(), dataset_train[1][0].max())

tensor(-0.7892) tensor(0.5592)
tensor(-1.) tensor(1.)


In [6]:
#@title classes: Encoder, CNNDiscriminator
class Encoder(torch.nn.Module):
    def __init__(self, n_c_in, n_c_out, p_dropout=0):
        super().__init__()
        self.b1 = torch.nn.Sequential(
            torch.nn.Conv2d(n_c_in, n_c_out, kernel_size=3, stride=1, padding=1),
            # torch.nn.BatchNorm2d(n_c_out),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Dropout(p_dropout),
        )
        # self.b2 = torch.nn.Sequential(
        #     torch.nn.Conv2d(n_c_out, n_c_out, kernel_size=3, stride=1, padding=1),
        #     torch.nn.BatchNorm2d(n_c_out),
        #     torch.nn.ReLU(inplace=True),
        #     torch.nn.Dropout(p_dropout),
        # )
        # self.b3 = torch.nn.Sequential(
        #     torch.nn.Conv2d(n_c_out, n_c_out, kernel_size=3, stride=1, padding=1),
        #     torch.nn.BatchNorm2d(n_c_out),
        #     torch.nn.ReLU(inplace=True),
        #     torch.nn.Dropout(p_dropout),
        # )

    def forward(self, x):
        out1 = self.b1(x)
        return out1
        # r1 = out1  # no residual since input is different size
        # out2 = self.b2(r1)
        # r2 = r1 + out2
        # out3 = self.b3(r2)
        # r3 = r2 + out3
        # return r3

class CNNDiscriminator(torch.nn.Module):
  """
  A very simple CNN with "resnet" style blocks (not even sure if they truly are)
  """
  def __init__(self, n_channels, n_classes):
    super().__init__()
    self.encoder1 = Encoder(n_channels, 16)
    self.encoder2 = Encoder(16, 32)
    self.encoder3 = Encoder(32, 64)
    self.classifier = torch.nn.Linear(16 * 64, n_classes)

  @staticmethod
  def _encode_and_pool(encoder, inputs):
    encoded = encoder(inputs)
    pooled = torch.nn.functional.max_pool2d(encoded, kernel_size=2, stride=2)
    return encoded, pooled

  def forward(self, x):
    encoded1, pooled1 = self._encode_and_pool(self.encoder1, x)
    encoded2, pooled2 = self._encode_and_pool(self.encoder2, pooled1)
    encoded3, pooled3 = self._encode_and_pool(self.encoder3, pooled2)
    activations = self.classifier(pooled3.view(pooled3.shape[0], -1))
    probs = torch.nn.functional.softmax(activations, dim=-1)
    log_probs = torch.nn.functional.log_softmax(activations, dim=-1)
    return log_probs, probs

In [7]:
device = "cpu"
if torch.cuda.is_available():
  device = "cuda:0"

In [None]:
#@title MNIST model,loss,optim
model_mnist = CNNDiscriminator(1, 10).to(device)
n_params = sum([np.prod(param.size()) for param in model_mnist.parameters()])
print(f"{n_params} parameters")

# Test the model
images, labels = next(iter(dataloader_train))
print(images.shape, labels.shape)
_, pred_probs = model_mnist(images.to(device))
print(pred_probs.shape)

# Loss & optim
loss_mnist = torch.nn.NLLLoss().to(device)
optimizer_mnist = torch.optim.Adam(params=model_mnist.parameters())

33546 parameters
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 10])


In [9]:
#@title visualizing funcs
# import numpy as np
# import matplotlib.pyplot as plt
# from skimage.util import montage


# def show_img(im, figsize=None, ax=None, title=None):
#     import matplotlib.pyplot as plt
#     if not ax: fig, ax = plt.subplots(figsize=figsize)
#     ax.imshow(im, cmap='gray')
#     if title is not None: ax.set_title(title, fontsize=50)
#     ax.get_xaxis().set_visible(False)
#     ax.get_yaxis().set_visible(False)
#     return ax


# def draw_rect(ax, bbox, edgecolor='red'):
#     import matplotlib.patches as patches
#     x, y, w, h = bbox
#     patch = ax.add_patch(patches.Rectangle((x, y), w, h, fill=False, edgecolor=edgecolor, lw=2))


# def draw_canvas(img, bboxes: np.ndarray, color='red'):
#     fig, ax = plt.subplots(1, 1, figsize=(192 / 20, 108 / 20))

#     for ix in range(len(bboxes)):
#         bbox = bboxes[ix]

#         draw_rect(ax, bbox, edgecolor=color)  # will add red bounding boxes for each character

#     ax = show_img(img, ax=ax)
#     plt.xticks([])
#     plt.yticks([])


# def visualize_data(dataloader, n=5, fields_to_print=None):
#     for ix, data in enumerate(dataloader):
#         if ix >= n:
#             break

#         img = data["image"][0].data.cpu().numpy().transpose(1, 2, 0)
#         dist = data["dists"][0].data.cpu().numpy().astype(np.uint8)
#         word = data["labels"][0].data.cpu().numpy().astype(np.float32)

#         if fields_to_print is not None:
#             for field in fields_to_print:
#                 print(data[field])

#         # Remove the -100s in the padding
#         pad_idx = np.where(word == -100)
#         dist[pad_idx] = MAX_DIST - 1
#         word[pad_idx] = 0

#         print(img.shape, dist.shape)
#         plt.figure(figsize=(20, 5))
#         plt.imshow(img.astype(np.uint8))
#         plt.show()
#         plt.figure(figsize=(20, 5))
#         plt.title(np.unique(dist))
#         plt.imshow(dist.astype(np.uint8))
#         plt.show()
#         plt.figure(figsize=(20, 5))
#         plt.title(np.unique(word))
#         plt.imshow(word.astype(np.float32))
#         plt.show()
#         print('--------------')


# def visualize_data_outputs(images, target_label, target_dist, pred_labels, pred_dists):
#     pad_idxs = np.where(target_dist == -100)
#     target_label[pad_idxs] = 0
#     pred_labels[pad_idxs] = 0
#     target_dist[pad_idxs] = MAX_DIST - 1
#     pred_dists[pad_idxs] = MAX_DIST - 1

#     images = images.astype(np.int)

#     ndim = images.shape[0]
#     if ndim > 1:
#         images = montage(images, multichannel=True)
#         target_dist = montage(target_dist)
#         target_label = montage(target_label)
#         pred_labels = montage(pred_labels)
#         pred_dists = montage(pred_dists)
#     else:
#         images = images[0]
#         target_dist = target_dist[0]
#         target_label = target_label[0]
#         pred_labels = pred_labels[0]
#         pred_dists = pred_dists[0]

#     N = 5
#     plt.figure(figsize=(20, 10))
#     plt.subplot(1, N, 1)
#     plt.imshow(images)
#     plt.title('Input images')

#     plt.subplot(1, N, 2)
#     plt.imshow(target_label)
#     plt.title('Target labels')

#     plt.subplot(1, N, 3)
#     plt.imshow(pred_labels)
#     plt.title('Predicted labels')

#     plt.subplot(1, N, 4)
#     plt.imshow(target_dist)
#     plt.title('Target dists')

#     plt.subplot(1, N, 5)
#     plt.imshow(pred_dists)
#     plt.title('Predicted dists')

#     plt.show()

In [15]:
version = time.strftime("%Y%m%d-%H%M%S")
log_dir = f"zcode/tb-logs/vanilla-gan-{version}"
logger_d = SummaryWriter(logdir=log_dir)
print(f"logging to {log_dir}")

logging to zcode/tb-logs/vanilla-gan-20250429-153005


In [17]:
%load_ext tensorboard
%tensorboard --logdir "{log_dir}"

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 73673), started 0:00:02 ago. (Use '!kill 73673' to kill it.)

## Train / val functions

In [18]:
#@title func: validate
def validate(dataloader, n_batches, model, loss_func, log_func=None):
  model.train(False)
  model.eval()
  device = next(model.parameters()).device

  loss = 0
  n_correct, n_total = 0, 0
  for ix, batch_val in enumerate(dataloader):
    if ix >= n_batches: break
    images_val, labels_val = batch_val

    # Get the loss
    log_probs, _ = model(images_val.to(device))
    loss += loss_func(log_probs, labels_val.to(device)).data.cpu().numpy()

    # Get the acc
    preds = np.argmax(log_probs.data.cpu().numpy(), axis=-1)
    n_correct += np.sum(preds == labels_val)
    n_total += labels_val.shape[0]

    # if ix == 0:
    #   image_montage = montage(images_val.squeeze().data.cpu().numpy())
    #   plt.figure(figsize=(5, 5))
    #   plt.subplot(1, 1, 1)
    #   plt.imshow(image_montage)
    #   plt.title(f"Labels: {labels_val.data.cpu().numpy()}\nPreds: {preds}")
    #   plt.show()

  model.train(True)
  loss /= n_batches
  if log_func:
    log_func(tag="disc.loss.val", scalar_value=loss)
    log_func(tag="disc.acc.val", scalar_value=n_correct/n_total)
  return loss

In [19]:
#@title func: train
import math


def train(dataloader, n_epochs, model, loss_func, optimizer,
          logger, callback_frequency=100, callbacks=None,
          es_smoothing=0.6, es_threshold=1e-8
          ) -> None:
  # Max steps to run for (in case of fractional n_epochs)
  max_steps = len(dataloader)
  if n_epochs < 1:
    max_steps = math.ceil(n_epochs * max_steps)
    n_epochs = 1

  if callbacks is None:
    callbacks = []

  device = next(model.parameters()).device

  # Running loss - for early stopping
  average_loss = 0

  for epoch in range(n_epochs):
    for ix, batch_train in enumerate(dataloader):
      step = ix + (epoch * len(dataloader))
      if ix >= max_steps:
        break
    
      # Get a train batch
      images_train, labels_train = batch_train
      
      # Get preds
      pred_log_probs_train, pred_probs_train = model(images_train.to(device))
      
      # calculate loss
      optimizer.zero_grad()
      loss_d_train = loss_func(pred_log_probs_train, labels_train.to(device))
      
      # backward
      loss_d_train.backward()
      optimizer.step()
    
      loss_val = loss_d_train.data.cpu().numpy()
      # tensorboard logging
      log_func = functools.partial(logger.add_scalar, global_step=step)
      log_func(tag="disc.loss.train", scalar_value=loss_val)

      # early stopping
      if np.abs(loss_val - average_loss) < es_threshold:
        print(f"Stopped early at iteration {step} with {loss_val}, average: {average_loss}")
        return
      average_loss = average_loss * es_smoothing + loss_val * (1 - es_smoothing)

      # Call the callbacks!
      if ix % callback_frequency == 0:
        _ = [f(model=model, log_func=log_func, loss_func=loss_func) for f in callbacks]

In [20]:
#@title class: Deconv, Generator
class Deconv(torch.nn.Module):
    def __init__(self, n_c_in, n_c_out):
        super().__init__()
        self.b1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(n_c_in, n_c_out, kernel_size=2, stride=2, padding=0),
            torch.nn.BatchNorm2d(n_c_out),
            torch.nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.b1(x)

class Generator(torch.nn.Module):
    def __init__(self, p_dropout=0):
        super().__init__()
        # deconv are the set of layers that enlarge the encoded image to double size
        self.deconv1 = Deconv(64, 32)
        self.deconv2 = Deconv(32, 32)
        self.deconv3 = Deconv(32, 16)
        # decoder uses deconvolution to creates the segmented image from encoded images using U-net structure
        self.decoder1 = Encoder(32, 32, p_dropout)
        self.decoder2 = Encoder(32, 32, p_dropout)
        self.decoder3 = Encoder(16, 16, p_dropout)
        self.word_pred = torch.nn.Sequential(
            torch.nn.Conv2d(16, 1, kernel_size=(1, 1), stride=1),
        )

    def forward(self, inputs):
        deconved1 = self.deconv1(inputs)
        decoded1 = self.decoder1(deconved1)
        deconved2 = self.deconv2(decoded1)
        decoded2 = self.decoder2(deconved2)
        deconved3 = self.deconv3(decoded2)
        decoded3 = self.decoder3(deconved3)
        activations = self.word_pred(decoded3)
        return torch.tanh(activations)

In [21]:
#@title Generator configs
batch_size_g =   32#@param
n_epochs_g =   1#@param

n_epochs_d =   1#@param
batch_size_d = 32  #@param
n_cycles = 700  #@param

half_bsz = batch_size_d // 2

In [22]:
#@title class: RealGenDataset
class RealGenDataset(torch.utils.data.Dataset):
  """
  Discriminator Labels:
    - Generated: 1
    - Real: 0
  
  If using with "only_generated", I flip the labels:
    - Generated: 0
  
  This is simply done for easy code reuse while training the generator since
  we want the discriminator to think that the generated images are real
  """

  def __init__(self, dataset_real, model_gen, only_generated=False):
    self.dataset = dataset_real
    self.generator = model_gen
    self.device = next(model_gen.parameters()).device
    self.only_generated = only_generated
  
  def __len__(self):
    return len(self.dataset)
  
  def __getitem__(self, ix, rng=None):
    image_real, _ = self.dataset[ix]

    if rng is not None:
      np.random.seed(rng)
    
    z = torch.from_numpy(np.random.randn(1, 64, 4, 4).astype(np.float32))
    # Note: still attached to the graph
    image_gen = self.generator(z.to(self.device)).cpu() 

    if self.only_generated:
      return image_gen, torch.from_numpy(np.array([0]).astype(np.long))

    images = torch.cat([image_real.view(1, *image_real.shape), image_gen], dim=0)
    labels = torch.from_numpy(np.array([0, 1]).astype(np.long))
    return (images, labels)

def collate_batches(batches):
  images = torch.cat([b[0] for b in batches], dim=0)
  labels = torch.cat([b[1] for b in batches], dim=0)
  return images, labels

In [23]:
model_g = Generator().to(device)
model_d = CNNDiscriminator(n_channels=1, n_classes=2).to(device)

In [24]:
#@title get dataset_train_d dataloader_train_d dataset_test_d dataloader_test_d

dataset_train_d = RealGenDataset(dataset_train, model_g)
dataloader_train_d = DataLoader(
    dataset_train_d,
    batch_size=half_bsz,
    shuffle=True,
    collate_fn=collate_batches
    )

dataset_test_d = RealGenDataset(dataset_test, model_g)
dataloader_test_d = DataLoader(
    dataset_test_d,
    batch_size=half_bsz,
    shuffle=True,
    collate_fn=collate_batches
    )

# Check if it makes sense
images, labels = next(iter(dataloader_train_d))
print(images.shape, labels.shape, labels)

torch.Size([32, 1, 32, 32]) torch.Size([32]) tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1])


In [25]:
#@title get dataset_train_g dataloader_train_g dataset_test_g dataloader_test_g

dataset_train_g = RealGenDataset(dataset_train, model_g, only_generated=True)
dataloader_train_g = DataLoader(
    dataset_train_g,
    batch_size=batch_size_g,
    shuffle=True,
    collate_fn=collate_batches
    )

dataset_test_g = RealGenDataset(dataset_test, model_g, only_generated=True)
dataloader_test_g = DataLoader(
    dataset_test_g,
    batch_size=batch_size_g,
    shuffle=True,
    collate_fn=collate_batches
    )

# Check if it makes sense
images, labels = next(iter(dataloader_train_g))
print(images.shape, labels.shape, labels)

torch.Size([32, 1, 32, 32]) torch.Size([32]) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])


In [26]:
#@title loss functions and optimizers
loss_d = torch.nn.NLLLoss()
loss_g = torch.nn.NLLLoss()
optimizer_d = torch.optim.Adam(model_d.parameters())
optimizer_g = torch.optim.Adam(model_g.parameters())

In [27]:
def train_one_batch(batch, model, optimizer, loss_func):
  # Get a train batch
  images_train, labels_train = batch
  
  # Get preds
  pred_log_probs_train, pred_probs_train = model(images_train.to(device))
  
  # calculate loss
  optimizer.zero_grad()
  loss_d_train = loss_func(pred_log_probs_train, labels_train.to(device))
  
  # backward
  loss_d_train.backward()
  optimizer.step()

  return loss_d_train

In [28]:
def print_foo():
  print(f"G: {next(model_g.parameters()).mean(), next(model_g.parameters()).std()}")
  print(f"D: {next(model_d.parameters()).mean(), next(model_d.parameters()).std()}")

In [29]:
#@title train gan
import time

val_func_d = functools.partial(validate, dataloader=dataloader_test_d, n_batches=5)
val_func_g = functools.partial(validate, dataloader=dataloader_test_g, n_batches=5)

step =   0   ##@param
start_time = time.time()
while True:
  step += 1

  # Train the discrimator on one batch
  model_d.train(True)
  model_g.train(False)
  loss_train_d = train_one_batch(
      batch=next(iter(dataloader_train_d)),
      model=model_d,
      optimizer=optimizer_d,
      loss_func=loss_d
  )
  logger_d.add_scalar("dis.loss.train", loss_train_d.data.cpu().numpy(), step)

  # Train the generator on one batch
  model_d.train(False)
  model_g.train(True)
  loss_train_g = train_one_batch(
      batch=next(iter(dataloader_train_g)),
      model=model_d,
      optimizer=optimizer_g,
      loss_func=loss_g
  )
  logger_d.add_scalar("gen.loss.train", loss_train_g.data.cpu().numpy(), step)

  # # Validation 
  # if step % 200 == 0:
  #   loss_val_g = val_func_g(model=model_d, loss_func=loss_g)
  #   loss_val_d = val_func_d(model=model_d, loss_func=loss_d)
  #   logger_d.add_scalar("gen.loss.val", loss_val_g, step)
  #   logger_d.add_scalar("dis.loss.val", loss_val_d, step)
  
  # # Save
  # if step % 2000 == 0:
  #   torch.save(model_d, log_dir + f"model_d_{step}.pt")
  #   torch.save(model_g, log_dir + f"model_g_{step}.pt")
  #   print(f"seconds elapsed since last checkpoint: {time.time() - start_time}")
  #   start_time = time.time()

KeyboardInterrupt: 

In [30]:
torch.save(model_d, log_dir + f"model_d_{step}.pt")
torch.save(model_g, log_dir + f"model_g_{step}.pt")
print(f"seconds elapsed since last checkpoint: {time.time() - start_time}")
start_time = time.time()

seconds elapsed since last checkpoint: 7325.823907852173


In [31]:
step

44769