# GAN

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import fastai
import torch
from fastai.vision.all import *
from fastai.vision.gan import *
from torch import nn

In [3]:
print(fastai.__version__) # version check

2.7.15


### Data

In [4]:
embedding = torch.load('embedding.pkl')

In [6]:
# Custom dataset class
class Txt2ImgDataset(Dataset):
    def __init__(self, items, embedding=embedding, image_size=128):
        self.items = items
        self.embedding = embedding
        self.image_size = image_size

    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, idx):
        fn = self.items[idx]
        key = fn.name.split('.')[0] + '.txt'
        emb = self.embedding[key].cuda()
        img = PILImage.create(fn).resize((self.image_size, self.image_size))
        wrong_img = self.get_wrong_image(idx)
        noise = torch.randn(100, 1, 1).cuda()
        emb_with_noise = torch.cat([emb, noise], 0)
        return emb_with_noise, img, wrong_img

    def get_wrong_image(self, idx):
        wrong_idx = random.choice(range(len(self.items)))
        while wrong_idx == idx:
            wrong_idx = random.choice(range(len(self.items)))
        wrong_fn = self.items[wrong_idx]
        wrong_img = PILImage.create(wrong_fn).resize((self.image_size, self.image_size))
        return wrong_img


NameError: name 'Dataset' is not defined

In [5]:
# Create DataLoaders
def get_dls(path, bs, size):
    items = get_image_files(path)
    dset = Txt2ImgDataset(items, image_size=size)
    return DataLoaders.from_dsets(dset, dset, bs=bs, shuffle=True)

dls = get_dls('data/images/', bs=64, size=128)

NameError: name 'Txt2ImgDataset' is not defined

In [None]:
class Txt2ImgYTransform(Transform):
    def __init__(self, embedding=embedding, **kwargs):
        self.embedding = embedding
    
    def encodes(self, fn):
        key = fn.stem + '.txt'
        img = PILImage.create(fn)
        return (TensorImage(self.embedding[key]).cuda(), img)
    
    def decodes(self, x):
        return PILImage(x[1].float().clamp(min=0, max=1))
        
    def show(self, xs, ys=None, zs=None, imgsize=4, figsize=None, **kwargs):
        raise NotImplementedError

In [None]:
class Txt2ImgXTransform(Transform):
    def __init__(self, embedding=embedding, **kwargs):
        self.embedding = embedding
    
    def encodes(self, fn):
        key = fn.stem + '.txt'
        img = self.get_wrong_image(fn)
        return (TensorImage(self.embedding[key]).cuda(), img)
    
    def get_wrong_image(self, fn):
        cat = '_'.join(fn.name.split('_')[:-3])
        items = fn.ls()  # List of all items in the directory
        idx = np.random.randint(len(items))
        
        while items[idx].name.startswith(cat):
            idx = np.random.randint(len(items))
        
        return PILImage.create(items[idx])

    def decodes(self, x):
        return PILImage(x[1].float().clamp(min=0, max=1))
        
    def show(self, xs, ys=None, zs=None, imgsize=4, figsize=None, **kwargs):
        raise NotImplementedError

In [None]:
batch_size = 128

In [None]:
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   splitter=FuncSplitter(lambda x: False),
                   get_y=noop)

data = dblock.dataloaders('data/images/train', bs=batch_size)

In [None]:
item = next(iter(data.train))

In [None]:
item[0].shape, item[0].shape, item[1].shape, item[1].shape 

## recheck the dataset 

## Model

In [None]:
def avg_flatten(x): return x.mean(0).view(1)

def AvgFlatten(): return Lambda(avg_flatten)  # now can pickle

In [None]:
in_size = 128

In [None]:
def squeezer(in_dim, out_dim):
    return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )

In [None]:
class imageGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = basic_generator(in_size=in_size, n_channels=3, noise_sz=228)
        self.squeezer = squeezer(400, 128)
    
    def forward(self, embedding, fake_image=None):        
        em_s = self.squeezer(embedding.view(embedding.size(0), -1))
        em_s = em_s[:,:,None,None]
        em_noise = torch.cat([em_s, torch.randn(em_s.size(0),100,1,1).cuda()], 1)
        # return: (embedding, fake image)
        return embedding, self.generator(em_noise)

In [None]:
class imageCritic(nn.Module):
    def __init__(self):
        super().__init__()
        critic = basic_critic(in_size=in_size, n_channels=3)
        self.body = nn.Sequential(*list(critic.children())[:-2])
        self.head = nn.Sequential(conv2d(640, 1, 4, padding=0), 
                                  AvgFlatten())
        self.squeezer = squeezer(400, 128)
        
    def forward(self, embedding, image):
        x = self.body(image)                     # (512,4,4)
        em_s = self.squeezer(embedding.view(embedding.size(0), -1)) 
        em_s = em_s[:,:,None,None]               # (128,1,1)
        em_s = em_s.repeat(1,1,4,4)              # (128,4,4)
        x = torch.cat([x, em_s], 1)              # (640,4,4)
        x = self.head(x)
        return x

In [None]:
class Loss(GANModule):    
    def __init__(self, gan_model):
        super().__init__()
        self.gan_model = gan_model

    def generator(self, output, *target):
        # output: (embedding, image)
        # target: (embedding, image)
        fake_pred = self.gan_model.critic(*output)        
        return fake_pred.mean()

    def critic(self, real_pred, embedding, wrong_img=None):
        # real_pred: (1,)
        fake = self.gan_model.generator(embedding.requires_grad_(False))
        # fake: (embedding, fake image)
        fake[1].requires_grad_(True)
        fake_pred = self.gan_model.critic(*fake)  
        #wrong_pred = self.gan_model.critic(fake[0], wrong_img)
        return real_pred.mean() - fake_pred.mean() #- wrong_pred.mean()

In [None]:
class Txt2ImgGANTrainer(GANTrainer):    
    def on_backward_begin(self, last_loss, last_output, **kwargs):        
        last_loss = last_loss.detach().cpu()
        if self.gen_mode:
            self.smoothenerG.add_value(last_loss)
            self.glosses.append(self.smoothenerG.smooth)
            # last_output: (embedding, image)
            self.last_gen = last_output[1].detach().cpu()
        else:
            self.smoothenerC.add_value(last_loss)
            self.closses.append(self.smoothenerC.smooth)

In [None]:
generator = imageGenerator()
critic = imageCritic()