In [2]:

import glob
import os
import random

import PIL
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ruclip
from rudalle import get_rudalle_model, get_vae, get_tokenizer, get_realesrgan
from rudalle.pipelines import generate_images, show, cherry_pick_by_ruclip, super_resolution
from rudalle.utils import seed_everything
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as T
from tqdm import tqdm
from transformers import AdamW
from translatepy import Translate
import wandb

In [3]:
class PathConfiguration(object):
    def __init__(self):
        self.rudalle_cache_dir = '/data/workspace/rudalle'
        self.checkpoint_dir = '/data/workspace/checkpoints'
        self.training_image_path = '/data/workspace/images'
        self.data_path = '/data/workspace/data_desc.csv'

In [4]:
class TrainingConfiguration(object):
    def __init__(self, model_instance):
        self.model = model_instance
        self.model_name = 'tuned_model'
        self.text_seq_length = self.model.get_param('text_seq_length')
        self.total_seq_length = self.model.get_param('total_seq_length')
        self.save_every = 200
        self.prefix_length = 10
        self.bs = 1
        self.clip = 0.24
        self.lr = 1e-4
        self.warmup_steps = 50
        self.epochs = 10
        self.wandb = False

In [5]:
class RuDalleDataset(Dataset):
    clip_filter_thr = 0.24
    def __init__(
            self,
            file_path,
            csv_path,
            tokenizer,
            resize_ratio=0.75,
            shuffle=True,
            load_first=None,
            caption_score_thr=0.6
    ):
       
        self.text_seq_length = model.get_param('text_seq_length')
        self.tokenizer = tokenizer
        self.target_image_size = 256
        self.image_size=256
        self.samples = []


        self.image_transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.RandomResizedCrop(self.image_size,
                                    scale=(1., 1.), # в train было scale=(0.75., 1.),
                                    ratio=(1., 1.)),
                T.ToTensor()
            ])
        
        df = pd.read_csv(csv_path)
        for caption, f_path in zip(df['caption'], df['name']):
            if os.path.isfile(f'{file_path}/{f_path}'):
                #Note: You may want to perform a translation here on the caption... I don't see a difference
              self.samples.append([file_path, f_path, caption])
        if shuffle:
            np.random.shuffle(self.samples)
            print('Shuffled')
    
    def __len__(self):
        return len(self.samples)

    def load_image(self, file_path, img_name):
        image = PIL.Image.open(f'{file_path}/{img_name}')
        return image

    def __getitem__(self, item):
        item = item % len(self.samples)  # infinite loop, modulo dataset size
        file_path, img_name, text = self.samples[item]
        try:
          image = self.load_image(file_path, img_name)
          image = self.image_transform(image).to(device)
        except Exception as err:  # noqa
            print(err)
            random_item = random.randint(0, len(self.samples) - 1)
            return self.__getitem__(random_item)
        text =  tokenizer.encode_text(text, text_seq_length=self.text_seq_length).squeeze(0).to(device)
        return text, image

In [6]:
path_configuration: PathConfiguration = PathConfiguration()

translation_engine = Translate()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device, cache_dir=path_configuration.rudalle_cache_dir)

training_configuration: TrainingConfiguration = TrainingConfiguration(model)

vae = get_vae(dwt=True).to(device)

model_path = os.path.join(path_configuration.checkpoint_dir, f"{training_configuration.model_name}_dalle_last.pt")

if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))

tokenizer = get_tokenizer()

  0%|          | 0/398 [00:00<?, ?it/s]

◼️ Malevich is 1.3 billion params model from the family GPT3-like, that uses Russian language and text+image multi-modality.
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
vae --> ready
tokenizer --> ready


In [7]:
input_files = [os.path.join(path_configuration.training_image_path, item) for item in os.listdir(path_configuration.training_image_path)]

with open(path_configuration.data_path, 'w',encoding="utf-8") as f:
    header = "caption,name\n"
    f.write(header)
    for elem in input_files:
        foo = os.path.split(elem)[-1]
        generic = 'A red head woman'
        translated = translation_engine.translate(generic, source_language='EN', destination_language="ru").result
        f.write(f"{generic},{foo}\n")
        f.write(f"{translated},{foo}\n")

# Training
st = RuDalleDataset(file_path=path_configuration.training_image_path, csv_path=path_configuration.data_path, tokenizer=tokenizer)

training_configuration.wandb = False

train_dataloader = DataLoader(st, batch_size=training_configuration.bs, shuffle=True, drop_last=True)

model.train()

optimizer = AdamW(model.parameters(), lr = training_configuration.lr)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, 
    max_lr=training_configuration.lr,
    final_div_factor=500,
    steps_per_epoch=len(train_dataloader), 
    epochs=training_configuration.epochs)

Shuffled


In [8]:
def freeze(
    model,
    freeze_emb=True,
    freeze_ln=False,
    freeze_attn=False,
    freeze_ff=True,
    freeze_other=True,
):
    for name, p in model.module.named_parameters():
        name = name.lower()
        if 'ln' in name or 'norm' in name:
            p.requires_grad = not freeze_ln
        elif 'embeddings' in name:
            p.requires_grad = not freeze_emb
        elif 'mlp' in name:
            p.requires_grad = not freeze_ff
        elif 'attn' in name:
            p.requires_grad = not freeze_attn
        else:
            p.requires_grad = not freeze_other
    return model

In [9]:
def train(model, training_args: TrainingConfiguration, path_args: PathConfiguration, train_dataloader: RuDalleDataset):
  loss_logs = []
  try:
    progress = tqdm(total=training_args.epochs * len(train_dataloader), desc='finetuning goes brrr')
    save_counter = 0
    for epoch in range(training_args.epochs):
      
      for text, images in train_dataloader:
        device = model.get_param('device')
        save_counter+=1
        model.zero_grad()
        attention_mask = torch.tril(torch.ones((training_args.bs, 1, training_args.total_seq_length, training_args.total_seq_length), device=device))
        image_input_ids = vae.get_codebook_indices(images)
        
        input_ids = torch.cat((text, image_input_ids), dim=1) 
        loss, loss_values = model.forward(input_ids, attention_mask, return_loss=True)
        #train step
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(),training_args.clip)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        #save every here
        if save_counter % training_args.save_every == 0:
          print(f'Saveing checkpoint here {training_args.model_name}_dalle_{save_counter}.pt')
          
          plt.plot(loss_logs)
          plt.show()
          torch.save(
                    model.state_dict(),
                    os.path.join(path_args.checkpoint_dir, f"{training_args.model_name}_dalle_{save_counter}.pt")
                    )
        if training_configuration.wandb:
          wandb.log({"loss":  loss.item()})
        loss_logs+=[loss.item()]
        progress.update()
        progress.set_postfix({"loss": loss.item()})
    
    print(f'Completed tuning and saved to: {training_configuration.model_name}_dalle_last.pt')
    
    plt.plot(loss_logs)
    plt.show()
    
    torch.save(
                model.state_dict(),
                os.path.join(path_configuration.checkpoint_dir, f"{path_configuration.checkpoint_dir}/{training_configuration.model_name}_dalle_last.pt"))
  
  except KeyboardInterrupt:
    print(f'What for did you stopped? Please change model_path to /{path_configuration.checkpoint_dir}/{training_args.model_name}_dalle_Failed_train.pt')
    plt.plot(loss_logs)
    plt.show()
    
    torch.save(
                model.state_dict(),
                os.path.join(path_args.checkpoint_dir,f"{path_args.checkpoint_dir}/{training_args.model_name}_dalle_Failed_train.pt"))
  except Exception as err:
    print(f'Failed with {err}')

In [10]:
train(model, training_configuration, path_configuration, train_dataloader)

model = freeze(model = model,
    freeze_emb=False,
    freeze_ln=False,
    freeze_attn=True,
    freeze_ff=True,
    freeze_other=False)


  row_ids = torch.arange(past_length, input_shape[-1] + past_length,

finetuning goes brrr:   0%|          | 1/1580 [00:17<7:45:53, 17.70s/it][A
finetuning goes brrr:   0%|          | 1/1580 [00:19<8:31:58, 19.45s/it, loss=5.66][A

Failed with CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 11.18 GiB total capacity; 10.13 GiB already allocated; 7.12 MiB free; 10.86 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF



