In [25]:
#!pip install medmnist

In [26]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import os
import time
import imageio

import numpy as np
import matplotlib.pyplot as plt

from IPython.display import clear_output

from medmnist import DermaMNIST
from medmnist.info import INFO

from torch.utils.data import DataLoader
from tqdm import tqdm

In [27]:
def load_emoji(index, path="datasets/emoji/emoji.png"):
    """Load one 40-pixel-wide tile from a horizontal emoji sprite sheet.

    Args:
        index: zero-based tile position along the sheet's width.
        path: image file containing the sprite sheet.

    Returns:
        The columns ``[index*40, (index+1)*40)`` of the sheet as a float32
        ndarray divided by 255 (i.e. in [0, 1] for 8-bit images).
    """
    sheet = imageio.imread(path)
    left = index * 40
    right = left + 40
    tile = np.array(sheet[:, left:right], dtype=np.float32)
    return tile / 255.0

In [28]:
class SamplePool:
    """A pool of equally sized arrays that can be sampled and written back.

    Every keyword argument in ``slots`` becomes an ndarray attribute of the
    same name; all slots must share the same length. ``sample`` returns a
    child pool holding the selected rows together with a link to this pool
    and the row indices, and ``commit`` copies the child's rows back into
    the parent at those indices.
    """

    def __init__(self, *, _parent=None, _parent_idx=None, **slots):
        self._parent = _parent
        self._parent_idx = _parent_idx
        self._slot_names = slots.keys()
        self._size = None
        for name, values in slots.items():
            length = len(values)
            if self._size is None:
                self._size = length
            # Every slot must have the same number of rows.
            assert self._size == length
            setattr(self, name, np.asarray(values))

    def sample(self, n):
        """Draw ``n`` distinct rows (without replacement) into a child pool."""
        picked = np.random.choice(self._size, n, replace=False)
        rows = {}
        for name in self._slot_names:
            rows[name] = getattr(self, name)[picked]
        return SamplePool(_parent=self, _parent_idx=picked, **rows)

    def commit(self):
        """Write this child pool's rows back into its parent at the sampled indices."""
        parent, indices = self._parent, self._parent_idx
        for name in self._slot_names:
            target = getattr(parent, name)
            target[indices] = getattr(self, name)

def make_seed(shape, n_channels):
    """Create an all-zero float32 state of shape ``(shape[0], shape[1], n_channels)``
    whose centre cell has every channel from index 3 onward set to 1.0."""
    height, width = shape[0], shape[1]
    state = np.zeros((height, width, n_channels), dtype=np.float32)
    state[height // 2, width // 2, 3:] = 1.0
    return state

In [29]:
class CAModel(nn.Module):
    """Neural cellular automaton update rule.

    ``forward``/``update`` take channels-last states ``(batch, H, W, channel_n)``.
    Channel 3 is used as the "alive" channel by ``alive``.

    Args:
        channel_n: number of state channels per cell.
        fire_rate: default probability that a cell applies its update on a step.
        device: torch device the parameters are moved to.
        hidden_size: width of the per-cell hidden layer.
    """

    def __init__(self, channel_n, fire_rate, device, hidden_size=128):
        super(CAModel, self).__init__()

        self.device = device
        self.channel_n = channel_n

        # Per-cell MLP: perception vector (state + two gradients) -> state delta.
        self.fc0 = nn.Linear(channel_n*3, hidden_size)
        self.fc1 = nn.Linear(hidden_size, channel_n, bias=False)
        # Zero-init the output layer so the untrained CA leaves states unchanged.
        with torch.no_grad():
            self.fc1.weight.zero_()

        self.fire_rate = fire_rate
        # Depthwise Sobel kernels keyed by (angle, device, dtype). They are
        # loop-invariant, so build and copy them to the device only once.
        self._kernel_cache = {}
        self.to(self.device)

    def alive(self, x):
        """Boolean mask of cells whose 3x3 neighbourhood has channel 3 > 0.1.

        ``x`` is channels-first ``(batch, channels, W, H)`` here, as produced inside ``update``.
        """
        return F.max_pool2d(x[:, 3:4, :, :], kernel_size=3, stride=1, padding=1) > 0.1

    def _sobel_kernels(self, angle, device, dtype):
        """Return the two rotated depthwise Sobel kernels, shape ``(channel_n, 1, 3, 3)``."""
        key = (float(angle), device, dtype)
        kernels = self._kernel_cache.get(key)
        if kernels is None:
            dx = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0  # Sobel filter
            dy = dx.T
            c = np.cos(angle*np.pi/180)
            s = np.sin(angle*np.pi/180)
            built = []
            for weight in (c*dx - s*dy, s*dx + c*dy):
                k = torch.from_numpy(weight.astype(np.float32)).to(device=device, dtype=dtype)
                built.append(k.view(1, 1, 3, 3).repeat(self.channel_n, 1, 1, 1))
            kernels = tuple(built)
            self._kernel_cache[key] = kernels
        return kernels

    def perceive(self, x, angle):
        """Concatenate the state with two Sobel-filtered copies of it.

        Args:
            x: channels-first state ``(batch, channel_n, W, H)``.
            angle: rotation of the gradient filters, in degrees.

        Returns:
            Tensor with ``3 * channel_n`` channels along dim 1: ``x``, then the
            two filtered copies.
        """
        w1, w2 = self._sobel_kernels(angle, x.device, x.dtype)
        y1 = F.conv2d(x, w1, padding=1, groups=self.channel_n)
        y2 = F.conv2d(x, w2, padding=1, groups=self.channel_n)
        return torch.cat((x, y1, y2), 1)

    def update(self, x, fire_rate, angle):
        """Apply one stochastic CA step to a channels-last state ``x``.

        Args:
            x: state ``(batch, H, W, channel_n)``.
            fire_rate: probability that each cell applies its update this step;
                ``None`` means use ``self.fire_rate``.
            angle: rotation of the perception filters, in degrees.

        Returns:
            New state with the same shape as ``x``. Cells that are not alive
            both before and after the step are zeroed.
        """
        x = x.transpose(1, 3)  # channels-first for the conv / pooling ops
        pre_life_mask = self.alive(x)

        dx = self.perceive(x, angle)
        dx = dx.transpose(1, 3)  # back to channels-last for the per-cell MLP
        dx = self.fc0(dx)
        dx = F.relu(dx)
        dx = self.fc1(dx)

        if fire_rate is None:
            fire_rate = self.fire_rate
        # Each cell fires independently with probability `fire_rate` (note `<`,
        # not `>`, so the parameter means what its name says). The mask is drawn
        # directly on the target device to avoid a host->device copy per step.
        fire_mask = torch.rand((*dx.shape[:3], 1), device=dx.device) < fire_rate
        dx = dx * fire_mask.to(dx.dtype)

        x = x + dx.transpose(1, 3)

        post_life_mask = self.alive(x)
        life_mask = (pre_life_mask & post_life_mask).to(x.dtype)
        x = x * life_mask
        return x.transpose(1, 3)

    def forward(self, x, steps=1, fire_rate=None, angle=0.0):
        """Run ``steps`` CA updates on ``x`` and return the final state."""
        for _ in range(steps):
            x = self.update(x, fire_rate, angle)
        return x


In [30]:
USE_WANDB = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "remaster_1.pth"

# Seed every RNG the notebook uses: numpy (target shuffling) and torch
# (random initial CA states, stochastic cell updates).
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

CHANNEL_N = 16        # Number of CA state channels
TARGET_PADDING = 0    # Number of pixels used to pad the target image border (e.g. 16)
TARGET_SIZE = 64      # DermaMNIST image resolution to load

lr = 2e-3
lr_gamma = 0.9999
betas = (0.5, 0.5)
n_epoch = 1000

BATCH_SIZE = 8
POOL_SIZE = 1024
CELL_FIRE_RATE = 0.5

EXPERIMENT_TYPE = "Growing"
EXPERIMENT_MAP = {"Growing":0, "Persistent":1, "Regenerating":2}
EXPERIMENT_N = EXPERIMENT_MAP[EXPERIMENT_TYPE]

USE_PATTERN_POOL = [0, 1, 1][EXPERIMENT_N]
DERMAMNIST_CLASSES = INFO["dermamnist"]["label"]
MELANOMA_LABEL = 4    # DermaMNIST class index whose images are used as targets

dermaMnist_dataset = DermaMNIST(split="train", download=True, as_rgb=True, size=TARGET_SIZE)

# Keep only the selected class, converted to RGBA float32 in [0, 1].
melanoma_samples = [
    np.array(image.convert("RGBA"), dtype=np.float32) / 255
    for image, label in dermaMnist_dataset
    if label[0] == MELANOMA_LABEL
]
print(f"dataset length {len(melanoma_samples)}")
print(f"class: {DERMAMNIST_CLASSES.get(str(MELANOMA_LABEL), MELANOMA_LABEL)}")
assert melanoma_samples, f"no training images with label {MELANOMA_LABEL}"

np.random.shuffle(melanoma_samples)
target_img = np.stack(melanoma_samples)

p = TARGET_PADDING
pad_target = np.pad(target_img, [(0, 0), (p, p), (p, p), (0, 0)])
h, w = pad_target.shape[1:3]
pad_target = torch.from_numpy(pad_target.astype(np.float32)).to(device)

# shuffle=False is required: initial_images is indexed by batch position in the training loop.
train_dataloader = DataLoader(pad_target, batch_size=BATCH_SIZE, shuffle=False)

# One initial CA state per target: random values in channels 0-3, hidden channels zero.
# Sized from the padded targets so it always matches them.
initial_images = torch.rand(len(pad_target), h, w, CHANNEL_N)
initial_images[:, :, :, 4:] = 0

seed = make_seed((h, w), CHANNEL_N)
# The sample pool (POOL_SIZE full states, ~256 MiB at the default sizes) is only
# used by the Persistent/Regenerating experiments.
pool = SamplePool(x=np.repeat(seed[None, ...], POOL_SIZE, 0)) if USE_PATTERN_POOL else None

ca = CAModel(CHANNEL_N, CELL_FIRE_RATE, device)  # __init__ already moves it to `device`
# ca.load_state_dict(torch.load(model_path, map_location=device))

optimizer = optim.Adam(ca.parameters(), lr=lr, betas=betas)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)

if USE_WANDB:
    import wandb
    import secret  # local module (not shown) expected to provide `key`

    os.environ["WANDB_API_KEY"] = secret.key
    wandb.init(project="GrowingNCA")
    # wandb only allows one watch per model; log="all" records both
    # gradients and parameters from that single hook.
    wandb.watch(ca, log="all", log_freq=BATCH_SIZE)

# One entry per training batch, appended by the training loop.
loss_log = []

def train(x, target, steps, optimizer, scheduler, model=None):
    """Run one optimisation step of the CA on a batch.

    Unrolls the CA for ``steps`` iterations from ``x``. It then computes the MSE
    between the first four channels of the result and ``target``. Finally it
    steps both the optimizer and the learning-rate scheduler.

    Args:
        x: initial states, channels-last ``(batch, h, w, channels)``.
        target: target images matching ``x[..., :4]`` in shape.
        steps: number of CA iterations to unroll.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler; stepped once per call.
        model: the CA to train. Defaults to the notebook-global ``ca``, so
            existing calls keep working.

    Returns:
        ``(x, loss)``: the final states (still attached to the autograd graph)
        and the scalar loss tensor.
    """
    if model is None:
        model = ca
    x = model(x, steps=steps)
    loss = F.mse_loss(x[..., :4], target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    return x, loss

def loss_f(x, target):
    """Per-sample mean squared error between ``x[..., :4]`` and ``target``.

    The mean is taken over the last three axes, so one value is returned per
    leading (batch) index.
    """
    squared_error = (x[..., :4] - target) ** 2
    return squared_error.mean(dim=(-3, -2, -1))

CA_STEPS = 96  # CA iterations unrolled per training step

for i in tqdm(range(1, n_epoch+1)):
    epoch_losses = []
    # `target_batch` (not `pad_target`) so the global dataset tensor is not clobbered.
    for j, target_batch in enumerate(train_dataloader):
        batch_len = target_batch.shape[0]
        if USE_PATTERN_POOL:
            batch = pool.sample(batch_len)
            x0 = batch.x.copy()
            # Reset the highest-loss pool state to the seed. This is done in place
            # (no reordering) so batch.commit() writes each result back to the pool
            # index it came from, and states stay aligned with target_batch.
            # NOTE(review): this reseeds with the single-cell `seed`, while the
            # non-pool branch starts from random noise; confirm which is intended.
            with torch.no_grad():
                pool_losses = loss_f(torch.from_numpy(x0.astype(np.float32)).to(device), target_batch)
            x0[int(pool_losses.argmax())] = seed
            x0 = torch.from_numpy(x0.astype(np.float32))
        else:
            start = j * BATCH_SIZE
            x0 = initial_images[start:start + batch_len]
        x0 = x0.to(device)

        x, loss = train(x0, target_batch, CA_STEPS, optimizer, scheduler)
        loss_value = loss.item()

        if USE_WANDB:
            wandb.log({'model_loss': loss_value})

        if USE_PATTERN_POOL:
            batch.x[:] = x.detach().cpu().numpy()
            batch.commit()

        loss_log.append(loss_value)
        epoch_losses.append(loss_value)

    if i % 50 == 0:
        print(len(loss_log) - 1, "loss =", loss_value, "| epoch mean =", float(np.mean(epoch_losses)))
        torch.save(ca.state_dict(), model_path + f"_{i}")

if USE_WANDB:
    wandb.finish()


Using downloaded and verified file: C:\Users\Niclas\.medmnist\dermamnist_64.npz
dataset length 779


  0%|          | 3/1000 [00:34<3:13:40, 11.66s/it]