In [None]:
# Mount Google Drive so the project files under MyDrive are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Make the project folder the working directory: later cells use relative paths
# such as 'data/sp500.csv', './data' and './contents/...', and the local
# `consistency_models` import presumably resolves from here too.
# NOTE(review): hard-coded Drive path — only valid for this account's Drive layout.

%cd /content/drive/MyDrive/consistency_main


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/.shortcut-targets-by-id/1mzejNCRGpGXYCyDkryrn2UyJYG4XMbVj/consistency_main


In [None]:
import math
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from consistency_models import ConsistencyModel, kerras_boundaries

In [None]:
# Load the raw S&P 500 series plus the windowed slices derived from it.
# The slices table holds overlapping windows (each row is the previous one
# shifted by one step); the `_pca` file is, per its name, a PCA-reduced
# version of it (2 columns).
sp500, slices, slices_pca = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/sp500.csv',
        'data/sliced_data.csv',
        'data/sliced_data_pca.csv',
    )
)

In [None]:
# Dimensions of the PCA table as (rows, columns).
len(slices_pca), len(slices_pca.columns)

(3809, 2)

In [None]:
def mnist_dl(batch_size: int = 128, num_workers: int = 20, root: str = "./data"):
    """Build a shuffled DataLoader over the MNIST training split.

    Each image is padded from 28x28 to 32x32 and normalised from [0, 1] to
    [-1, 1] via ``(x - 0.5) / 0.5``. The inverse is ``x * 0.5 + 0.5``.

    Args:
        batch_size: Samples per batch (default 128, unchanged).
        num_workers: DataLoader worker processes (default 20, unchanged).
            NOTE(review): 20 workers is a lot for Colab; the ``os.fork()``
            warnings in the training output are likely related.
        root: Directory MNIST is downloaded to / read from.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    tf = transforms.Compose(
        [
            transforms.Pad(2),  # 28x28 -> 32x32
            transforms.ToTensor(),  # PIL uint8 -> float tensor in [0, 1]
            # mean and std are both documented as sequences; std was a bare float.
            transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
        ]
    )

    dataset = MNIST(root, train=True, download=True, transform=tf)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


In [None]:
# Standalone MNIST training set for inspecting samples, using the same
# preprocessing as mnist_dl. Note: `data.data` holds the raw uint8 28x28
# images; `tf` is only applied when indexing `data[i]`.
tf = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.ToTensor(),
        # std passed as a one-element tuple to match the mean (both are sequences).
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
data = MNIST(
    "./data",
    train=True,
    download=True,
    transform=tf,
)

In [None]:
# Size of the first raw (unpadded, untransformed) MNIST image.
data.data[0].size()

torch.Size([28, 28])

In [None]:
# Raw uint8 pixel values of the first MNIST training image.
data.data[0, :, :]

tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,  18,
          18,  18, 126, 136, 175,  26, 166, 255, 247, 127,   0,   0,   0,   0],
        [  0,   0,   0,   0,   0,   0,   0,   

In [None]:
# The slice table as a plain NumPy array (pandas recommends to_numpy() over .values).
slices.to_numpy()

array([[1.25842308e-04, 3.16526323e-03, 3.58095414e-03, ...,
        2.61955138e-02, 2.58751747e-02, 3.16565599e-02],
       [3.16526323e-03, 3.58095414e-03, 1.85339096e-03, ...,
        2.58751747e-02, 3.16565599e-02, 3.12675809e-02],
       [3.58095414e-03, 1.85339096e-03, 2.28433801e-03, ...,
        3.16565599e-02, 3.12675809e-02, 3.23620703e-02],
       ...,
       [8.59381401e-01, 8.68151501e-01, 8.84556373e-01, ...,
        9.90978535e-01, 9.86254340e-01, 1.00000000e+00],
       [8.68151501e-01, 8.84556373e-01, 8.82491449e-01, ...,
        9.86254340e-01, 1.00000000e+00, 9.94033979e-01],
       [8.84556373e-01, 8.82491449e-01, 8.68684866e-01, ...,
        1.00000000e+00, 9.94033979e-01, 9.82657915e-01]])

In [None]:
def slices_dl(data=None, batch_size: int = 128, num_workers: int = 20):
    """Build a shuffled DataLoader over the sliced S&P 500 windows.

    Each row of ``data`` becomes one float32 sample. It is paired with a dummy
    all-zero label so every batch unpacks as ``(x, label)``. This is the same
    structure the MNIST/CIFAR-10 loaders yield, and what ``train`` expects
    (``for x, _ in dataloader``).

    Args:
        data: DataFrame or array-like with one sample per row. Defaults to the
            module-level ``slices`` table.
        batch_size: Samples per batch (default 128, unchanged).
        num_workers: DataLoader worker processes (default 20, unchanged).

    Returns:
        DataLoader yielding ``[x, label]``. ``x`` has shape
        ``(batch, n_columns)`` and ``label`` has shape ``(batch,)``.
        NOTE(review): confirm ConsistencyModel accepts 2-D ``(batch, features)``
        input; the image loaders yield ``(batch, C, H, W)``.
    """
    if data is None:
        data = slices

    features = torch.tensor(np.asarray(data)).float()
    # Bare-tensor batches broke `x, _ = batch` unless a batch held exactly 2 rows.
    labels = torch.zeros(len(features), dtype=torch.long)
    dataset = TensorDataset(features, labels)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


In [None]:
def cifar10_dl(batch_size: int = 128, num_workers: int = 20, root: str = "./data"):
    """Build a shuffled DataLoader over the CIFAR-10 training split.

    Images are left at their native size and each RGB channel is normalised
    from [0, 1] to [-1, 1] via ``(x - 0.5) / 0.5``.

    Args:
        batch_size: Samples per batch (default 128, unchanged).
        num_workers: DataLoader worker processes (default 20, unchanged).
        root: Directory CIFAR-10 is downloaded to / read from.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    tf = transforms.Compose(
        [
            transforms.ToTensor(),  # PIL uint8 -> float tensor in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
        ]
    )

    dataset = CIFAR10(root, train=True, download=True, transform=tf)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
def _save_samples(model, like, device, ts, path):
    """Sample with ``model`` along noise levels ``ts`` and save a 4-column image grid to ``path``.

    Initial noise copies the shape of ``like`` and is scaled by 80.0, the top
    of the sampling schedules. Outputs are mapped from [-1, 1] to [0, 1].
    """
    xh = model.sample(torch.randn_like(like).to(device=device) * 80.0, ts)
    xh = (xh * 0.5 + 0.5).clamp(0, 1)
    save_image(make_grid(xh, nrow=4), path)


def train(
    n_epoch: int = 100,
    device="cuda:0",
    dataloader=None,
    n_channels=1,
    name="mnist",
):
    """Train a ConsistencyModel by consistency training with an EMA target network.

    Each epoch does the following:

    - Rebuilds the noise-level boundaries from ``kerras_boundaries`` with ``N``
      levels. ``N`` grows with ``epoch`` and reaches 150 at the final epoch.
    - Makes one pass over ``dataloader``, minimising ``model.loss`` against
      ``ema_model``.
    - EMA-updates ``ema_model`` after every optimiser step.
    - Writes 10-, 5- and 2-step sample grids to ``./contents/``.
    - Saves the online model's weights to ``./ct_{name}.pth``.

    Args:
        n_epoch: Number of epochs to run (epochs ``1..n_epoch`` inclusive).
        device: Torch device for both models and the data.
        dataloader: Iterable of ``(x, label)`` batches; labels are ignored.
            Samples are assumed to be normalised to [-1, 1], since sample grids
            apply ``x * 0.5 + 0.5``. Defaults to ``mnist_dl()``, built at call
            time.
        n_channels: Channel count passed to ``ConsistencyModel``.
        name: Tag used in the sample-image and checkpoint filenames.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    if dataloader is None:
        # Built here instead of as a default argument: `dataloader=mnist_dl()`
        # ran at definition time, downloading MNIST and starting a loader even
        # when the caller supplied its own dataloader.
        dataloader = mnist_dl()

    # save_image does not create missing parent directories.
    os.makedirs("./contents", exist_ok=True)

    model = ConsistencyModel(n_channels, D=256)
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=2e-4)

    # Define \theta_{-}, which is EMA of the params
    ema_model = ConsistencyModel(n_channels, D=256)
    ema_model.to(device)
    ema_model.load_state_dict(model.state_dict())

    # Descending noise levels for the few-step sampling previews, keyed by step count.
    sample_schedules = {
        10: [80.0, 72.0, 64.0, 56.0, 48.0, 40.0, 32.0, 24.0, 16.0, 8.0],
        5: [80.0, 40.0, 20.0, 10.0, 5.0],
        2: [80.0, 2.0],
    }

    # Inclusive upper bound: exactly `n_epoch` epochs run, and the final epoch
    # reaches N = 150 (the former `range(1, n_epoch)` ran one epoch fewer).
    for epoch in range(1, n_epoch + 1):
        N = math.ceil(math.sqrt((epoch * (150**2 - 4) / n_epoch) + 4) - 1) + 1
        boundaries = kerras_boundaries(7.0, 0.002, N, 80.0).to(device)
        # EMA decay depends only on N, so it is computed once per epoch.
        mu = math.exp(2 * math.log(0.95) / N)

        pbar = tqdm(dataloader)
        loss_ema = None
        x = None
        model.train()
        for x, _ in pbar:
            optim.zero_grad()
            x = x.to(device)

            z = torch.randn_like(x)
            # t in [0, N - 2], so t + 1 presumably stays inside the N boundaries.
            t = torch.randint(0, N - 1, (x.shape[0], 1), device=device)
            t_0 = boundaries[t]
            t_1 = boundaries[t + 1]

            loss = model.loss(x, z, t_0, t_1, ema_model=ema_model)
            loss.backward()
            optim.step()

            # Smoothed loss for the progress bar only.
            loss_ema = loss.item() if loss_ema is None else 0.9 * loss_ema + 0.1 * loss.item()

            with torch.no_grad():
                # update \theta_{-}
                for p, ema_p in zip(model.parameters(), ema_model.parameters()):
                    ema_p.mul_(mu).add_(p, alpha=1 - mu)

            pbar.set_description(f"Loss: {loss_ema:.10f}, Mu: {mu:.10f}")

        if x is None:
            raise ValueError("dataloader yielded no batches")

        model.eval()
        with torch.no_grad():
            # The last training batch only fixes the shape of the sampling noise.
            for n_steps, ts in sample_schedules.items():
                _save_samples(
                    model, x, device, ts,
                    f"./contents/ct_{name}_sample_{n_steps}step_{epoch}.png",
                )

        # save model
        torch.save(model.state_dict(), f"./ct_{name}.pth")




In [None]:
# Train on MNIST using train's default settings, spelled out explicitly.
train(n_epoch=100, device="cuda:0", n_channels=1, name="mnist")

  self.pid = os.fork()
Loss: 0.0035546355, Mu: 0.9936088490: 100%|██████████| 469/469 [00:41<00:00, 11.39it/s]
Loss: 0.0053500845, Mu: 0.9953478283: 100%|██████████| 469/469 [00:41<00:00, 11.39it/s]
Loss: 0.0028737190, Mu: 0.9962077057: 100%|██████████| 469/469 [00:41<00:00, 11.38it/s]
Loss: 0.0024182324, Mu: 0.9966962247: 100%|██████████| 469/469 [00:41<00:00, 11.42it/s]
Loss: 0.0025313237, Mu: 0.9969872947: 100%|██████████| 469/469 [00:40<00:00, 11.44it/s]
Loss: 0.0021712920, Mu: 0.9972312296: 100%|██████████| 469/469 [00:41<00:00, 11.42it/s]
Loss: 0.0021593276, Mu: 0.9974386212: 100%|██████████| 469/469 [00:41<00:00, 11.43it/s]
Loss: 0.0022381750, Mu: 0.9976171090: 100%|██████████| 469/469 [00:41<00:00, 11.43it/s]
Loss: 0.0019471995, Mu: 0.9977723417: 100%|██████████| 469/469 [00:41<00:00, 11.39it/s]
Loss: 0.0022901514, Mu: 0.9978650616: 100%|██████████| 469/469 [00:41<00:00, 11.43it/s]
Loss: 0.0018588313, Mu: 0.9979503716: 100%|██████████| 469/469 [00:41<00:00, 11.41it/s]
Loss: 0.0

In [None]:
# Load the PNG images imgs/img_0.png ... imgs/img_{N-1}.png, presumably
# renderings of the PCA slices. They get the same preprocessing as the MNIST
# cells: pad by 2px, convert to a float tensor in [0, 1], then normalise to [-1, 1].
# NOTE(review): ToTensor keeps every channel of the PNG (3 for RGB, 4 for RGBA);
# call img.convert("L") first if a single-channel model is intended.
N_PCA_IMAGES = 1000

pca_imgs = []
tf = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
for idx in range(N_PCA_IMAGES):
    # `with` closes the file; Image.open is lazy and otherwise keeps it open.
    with Image.open(f'imgs/img_{idx}.png') as img:
        pca_imgs.append(tf(img))

# One summary line instead of printing a shape per image.
print(f"loaded {len(pca_imgs)} images, first shape: {tuple(pca_imgs[0].shape)}")


In [None]:
# Stack the per-image tensors along a new leading batch dimension.
# torch.tensor() cannot build a tensor from a list of multi-element tensors
# (it raises "only one element tensors can be converted to Python scalars").
dataset = torch.stack(pca_imgs)

In [None]:
# First image of the stacked PCA dataset.
dataset[0]

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], dtype=torch.float64)

In [None]:
# Size of the first image in the stacked PCA dataset.
dataset[0].size()

torch.Size([720, 720])

In [None]:
def pca_dl(dataset, batch_size: int = 128, num_workers: int = 20):
    """Build a shuffled DataLoader over the PCA images.

    A plain tensor (one sample per leading index) is paired with dummy
    all-zero labels, so batches unpack as ``(x, label)``. That is the structure
    ``train`` expects (``for x, _ in dataloader``); a bare tensor would yield
    unlabelled batches. Any other dataset is passed to ``DataLoader``
    unchanged.

    Args:
        dataset: A tensor of stacked samples, or any map-style dataset.
        batch_size: Samples per batch (default 128, unchanged).
        num_workers: DataLoader worker processes (default 20, unchanged).

    Returns:
        A shuffled ``DataLoader`` over ``dataset``.
    """
    if isinstance(dataset, torch.Tensor):
        labels = torch.zeros(len(dataset), dtype=torch.long)
        dataset = TensorDataset(dataset, labels)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
# Train on CIFAR-10 (3 RGB channels) for 30 epochs.
train(
    name="cifar10",
    n_channels=3,
    dataloader=cifar10_dl(),
    device="cuda:0",
    n_epoch=30,
)


Files already downloaded and verified


  self.pid = os.fork()
  self.pid = os.fork()
Loss: 0.0012061575, Mu: 0.9963428968: 100%|██████████| 391/391 [04:11<00:00,  1.56it/s]
Loss: 0.0004546793, Mu: 0.9973730312: 100%|██████████| 391/391 [04:21<00:00,  1.49it/s]
Loss: 0.0007240762, Mu: 0.9978650616: 100%|██████████| 391/391 [04:21<00:00,  1.50it/s]
Loss: 0.0002855310, Mu: 0.9981365277: 100%|██████████| 391/391 [04:21<00:00,  1.50it/s]
Loss: 0.0001934053, Mu: 0.9983467457: 100%|██████████| 391/391 [04:20<00:00,  1.50it/s]
Loss: 0.0001736533, Mu: 0.9984925111: 100%|██████████| 391/391 [04:20<00:00,  1.50it/s]
Loss: 0.0017928316, Mu: 0.9985956912: 100%|██████████| 391/391 [04:21<00:00,  1.50it/s]
Loss: 0.0001330570, Mu: 0.9986856518: 100%|██████████| 391/391 [04:22<00:00,  1.49it/s]
Loss: 0.0009260084, Mu: 0.9987647805: 100%|██████████| 391/391 [04:20<00:00,  1.50it/s]
Loss: 0.0006484732, Mu: 0.9988215387: 100%|██████████| 391/391 [04:20<00:00,  1.50it/s]
Loss: 0.0001458793, Mu: 0.9988733100: 100%|██████████| 391/391 [04:20<00:0