In [None]:
!nvidia-smi

In [None]:
#@title Load Kaggle API Token
from google.colab import files
files.upload()
!pip install -q kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download lifailifai/ucf101

# Install

In [None]:
# ruDALL-E library, pinned to release 0.0.1rc4.
!pip install rudalle==0.0.1rc4

In [None]:
# NOTE(review): unpinned - a newer transformers release may not be compatible with
# rudalle 0.0.1rc4; consider pinning the version this notebook was run with.
!pip install transformers

In [None]:
# Presumably fetches ucf100_ru2.csv (the en -> ru caption table read by the Dataset
# cell) - TODO confirm what this Drive file is.
!gdown https://drive.google.com/uc?id=1_B9Y2U9d-xIO1pKb4J9qsVltMCH_ZNOy

# Unzip data

In [None]:
#@markdown Lets download data
!unzip ucf101.zip

# Import

In [None]:
import io
import os
import PIL
import random
import numpy as np
import torch
import torchvision
import transformers
import more_itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm
from dataclasses import dataclass, field
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import cv2
from PIL import Image
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle.utils import seed_everything

# Load ruDALL-E

In [None]:
# Seed the global RNGs once, up front, so DataLoader shuffling and the dataset's
# random resampling of bad videos are reproducible across Restart & Run All.
SEED = 42
seed_everything(SEED)

device = 'cuda'
# Pretrained Malevich checkpoint, loaded in half precision.
model = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
# Image tokenizer: turns frames into codebook ids during training.
vae = get_vae().to(device)
# Text tokenizer for the captions.
tokenizer = get_tokenizer()

# Args

In [None]:
class Args:
    """Hyper-parameters and output paths for the fine-tuning run.

    Sequence lengths are read from the globally loaded ``model`` so they always
    match the checkpoint being fine-tuned.
    """

    def __init__(self):
        # Length of the text-token prefix, and of text + image tokens together
        # (the training loop builds its causal mask with total_seq_length).
        self.text_seq_length = model.get_param('text_seq_length')
        self.total_seq_length = model.get_param('total_seq_length')
        self.epochs = 1
        self.save_path = 'checkpoints/'
        self.model_name = 'awesomemodel_'
        self.save_every = 2000  # write a checkpoint every N training steps
        self.prefix_length = 10  # NOTE(review): not read by any cell in this notebook
        self.clip = 0.24  # max gradient norm passed to clip_grad_norm_
        self.lr = 2e-5  # peak learning rate (OneCycleLR max_lr)
        self.warmup_steps = 50  # NOTE(review): unused - OneCycleLR does its own warm-up


args = Args()
# exist_ok keeps the cell idempotent without a separate existence check.
os.makedirs(args.save_path, exist_ok=True)

# Dataset

In [None]:
def get_all_video_path(dir):
    """List every entry inside the class sub-folders of ``dir``.

    ``dir`` is expected to hold one sub-folder per class, each containing that
    class's videos. Non-directory entries directly under ``dir`` (e.g. a stray
    README) are skipped instead of crashing ``os.listdir``.

    Returns:
        list[str]: ``dir/<class>/<file>`` paths, sorted by class then file name so
        the order does not depend on the filesystem (keeps seeded shuffling
        reproducible).
    """
    paths = []
    for folder in sorted(os.listdir(dir)):
        class_dir = os.path.join(dir, folder)
        if not os.path.isdir(class_dir):
            continue
        paths.extend(os.path.join(class_dir, name) for name in sorted(os.listdir(class_dir)))
    return paths

In [None]:
def read_video(path, transform=None, frames_num=None):
    """Decode a video file into a stacked tensor of frames.

    Args:
        path: path to a video file readable by OpenCV.
        transform: callable applied to each RGB ``PIL.Image`` frame. It must return
            a tensor: the frames are combined with ``torch.stack``, which cannot stack
            raw PIL images.
        frames_num: if given, stop once this many frames have been read.

    Returns:
        ``torch.stack`` of the (transformed) frames; dim 0 is the frame count.

    Raises:
        ValueError: if no frame could be decoded (missing or corrupt file).
    """
    frames = []
    cap = cv2.VideoCapture(path)
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break  # end of stream or decode failure
            # OpenCV decodes to BGR; PIL / torchvision expect RGB.
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if transform is not None:
                frame = transform(frame)
            frames.append(frame)
            if frames_num is not None and len(frames) >= frames_num:
                break
    finally:
        # Release the decoder even when `transform` raises.
        cap.release()
    if not frames:
        raise ValueError(f'Could not read any frames from {path}')
    return torch.stack(frames)

In [None]:
class RuDalleVideoDataset(Dataset):
    """Videos from a class-per-folder tree paired with tokenized Russian captions.

    ``__getitem__`` returns ``(text, frames)``. ``text`` is the tokenized caption of
    the video's class. ``frames`` stacks the frames at ``frame_indices``, taken from
    the first ``num_frames`` decoded frames of the video.
    """

    # Cap on consecutive resampling attempts when videos fail to load, instead of
    # recursing until Python's recursion limit is hit.
    max_retries = 100

    def __init__(
            self,
            dir_path,
            csv_path,
            tokenizer,
            text_seq_length=None,
            num_frames=50,
            frame_indices=(0, 16, 32, 48),
    ):
        """
        Args:
            dir_path: root folder with one sub-folder per class, each holding videos.
            csv_path: CSV with an ``en`` column (matched against the class folder
                name) and a ``ru`` column (the caption), e.g. 'ucf100_ru2.csv'.
            tokenizer: an object with the methods of
                tokenizer_wrapper.BaseTokenizerWrapper (``encode_text`` is used).
            text_seq_length: caption token length; defaults to the global
                ``model``'s ``text_seq_length``.
            num_frames: a video must have at least this many frames; only that many
                are decoded.
            frame_indices: which of the decoded frames are returned.

        Raises:
            ValueError: if ``frame_indices`` is empty or not all below ``num_frames``.
        """
        frame_indices = list(frame_indices)
        if not frame_indices or max(frame_indices) >= num_frames:
            raise ValueError('frame_indices must be non-empty and smaller than num_frames')
        self.df = pd.read_csv(csv_path)
        # Captions are looked up by the English class name.
        self.df.index = self.df['en'].values
        self.paths = get_all_video_path(dir_path)
        if text_seq_length is None:
            text_seq_length = model.get_param('text_seq_length')
        self.text_seq_length = text_seq_length
        self.tokenizer = tokenizer
        self.num_frames = num_frames
        self.frame_indices = frame_indices
        self.image_size = 128

        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(self.image_size,
                                scale=(1., 1.),  # in training it was scale=(0.75, 1.)
                                ratio=(1., 1.)),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def _load_frames(self, path):
        """Decode ``path`` and return the selected frames; raise if it is too short."""
        video = read_video(path, self.image_transform, self.num_frames)
        if len(video) < self.num_frames:
            raise ValueError(f'Video is shorter than {self.num_frames} frames: {path}')
        return video[self.frame_indices]

    def __getitem__(self, item):
        path = self.paths[item]
        for _ in range(self.max_retries):
            try:
                frames = self._load_frames(path)
                break
            except Exception as err:  # noqa: BLE001 - any unusable video: resample
                print(err)
                path = self.paths[random.randint(0, len(self.paths) - 1)]
        else:
            raise RuntimeError(
                f'Failed to load a usable video after {self.max_retries} attempts')

        # The caption key is the name of the folder that contains the video.
        class_name = os.path.basename(os.path.dirname(path))
        # str.capitalize() upper-cases the first letter and lower-cases the rest.
        ru_text = self.df.loc[class_name]['ru'].capitalize()
        text = self.tokenizer.encode_text(ru_text, text_seq_length=self.text_seq_length).squeeze(0)
        return text, frames

In [None]:
from torch.utils.data import Dataset, DataLoader

# Extracted UCF101 videos (one folder per class) paired with the en -> ru caption CSV.
st = RuDalleVideoDataset(
    dir_path='/content/UCF-101',
    csv_path='ucf100_ru2.csv',
    tokenizer=tokenizer,
)
# One clip per step, reshuffled every epoch; incomplete trailing batches are dropped.
train_dataloader = DataLoader(
    dataset=st,
    batch_size=1,
    shuffle=True,
    drop_last=True,
)

# Train functions

In [None]:
def freeze(
    model,
    freeze_emb=True,
    freeze_ln=False,
    freeze_attn=False,
    freeze_ff=True,
    freeze_other=True,
):
    """Set ``requires_grad`` on every parameter according to its name.

    Each parameter is bucketed by its lower-cased name; the first match wins:
    ``ln``/``norm`` -> layer norms, ``embeddings`` -> embeddings, ``mlp`` ->
    feed-forward, ``attn`` -> attention, anything else -> other. A ``freeze_*`` flag
    of True disables gradients for that bucket; False enables them.

    Accepts either a wrapper exposing the network as ``.module`` or a bare
    ``nn.Module``.

    Returns:
        The same ``model`` object (modified in place), so calls can be chained.
    """
    network = getattr(model, 'module', model)
    for name, param in network.named_parameters():
        lowered = name.lower()
        if 'ln' in lowered or 'norm' in lowered:
            frozen = freeze_ln
        elif 'embeddings' in lowered:
            frozen = freeze_emb
        elif 'mlp' in lowered:
            frozen = freeze_ff
        elif 'attn' in lowered:
            frozen = freeze_attn
        else:
            frozen = freeze_other
        param.requires_grad = not frozen
    return model

In [None]:
#@markdown Simple training loop
def _plot_losses(loss_logs):
    """Show the per-step training loss curve collected so far."""
    fig, ax = plt.subplots()
    ax.plot(loss_logs)
    ax.set(title='Fine-tuning loss', xlabel='step', ylabel='loss')
    plt.show()
    plt.close(fig)


def train(model, args, train_dataloader, device):
    """Fine-tune ``model`` on ``(text, frames)`` batches from ``train_dataloader``.

    Relies on the notebook globals ``vae`` (frame tokenizer), ``optimizer`` and
    ``scheduler``; create them before calling this.

    Args:
        model: ruDALL-E model, called as ``model(input_ids, attention_mask,
            return_loss=True)``.
        args: ``Args`` instance (epochs, clip, save_every, save_path, model_name,
            total_seq_length).
        train_dataloader: yields ``text`` token ids and a batch of frame stacks
            ``images`` - presumably ``(B, n_frames, C, H, W)``; TODO confirm.
        device: kept for backward compatibility only; tensors are moved to the
            model's own device (``model.get_param('device')``).

    Side effects:
        Every ``args.save_every`` steps, plots the loss so far and saves a
        checkpoint to ``args.save_path``. At the end, saves the final weights to
        ``videodalle_new.pt``.
    """
    device = model.get_param('device')
    loss_logs = []
    save_counter = 0
    # The causal mask is identical for every step: build it once.
    attention_mask = torch.tril(
        torch.ones((1, 1, args.total_seq_length, args.total_seq_length), device=device))
    progress = tqdm(total=len(train_dataloader) * args.epochs, desc='finetuning')

    for _ in range(args.epochs):
        for text, images in train_dataloader:
            text = text.to(device)
            images = images.to(device)
            batch_size = images.shape[0]
            save_counter += 1

            model.zero_grad()
            with torch.no_grad():
                # Tokenize every frame of every clip, then lay each clip's frame
                # tokens back-to-back as that sample's image-token sequence.
                frame_codes = vae.get_codebook_indices(images.flatten(0, 1))
                image_input_ids = frame_codes.reshape(batch_size, -1)
            input_ids = torch.cat((text, image_input_ids), dim=1)
            loss, _ = model(input_ids, attention_mask, return_loss=True)

            # Train step.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            step_loss = loss.item()
            loss_logs.append(step_loss)

            if save_counter % args.save_every == 0:
                checkpoint_name = f"{args.model_name}_dalle_{save_counter}.pt"
                print(f'Saveing checkpoint here {checkpoint_name}')
                _plot_losses(loss_logs)
                torch.save(model.state_dict(), os.path.join(args.save_path, checkpoint_name))

            progress.update()
            progress.set_postfix({"loss": step_loss})

    progress.close()
    _plot_losses(loss_logs)
    final_path = 'videodalle_new.pt'
    torch.save(model.state_dict(), final_path)
    # Report the file that was actually written.
    print(f'Completely tuned and saved here {final_path}')

# Train

In [None]:
from transformers import get_linear_schedule_with_warmup  # not used by this cell

model.train()
# transformers.AdamW is deprecated (and removed in recent releases). torch's AdamW
# with eps=1e-6 and weight_decay=0.0 reproduces the transformers defaults.
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, eps=1e-6, weight_decay=0.0)

# One-cycle schedule: warms up to args.lr, then anneals down to
# args.lr / (div_factor * final_div_factor), with div_factor at its default of 25.
# Must be stepped exactly epochs * len(train_dataloader) times.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=args.lr,
    final_div_factor=500,
    steps_per_epoch=len(train_dataloader),
    epochs=args.epochs,
)

In [None]:
#@markdown Choose which parameter groups to fine-tune. Unfreezing more groups trains more weights, which costs more memory and compute per step.
model = freeze(
    model,
    freeze_emb=False,    # train parameters whose name contains 'embeddings'
    freeze_ln=False,     # train layer-norm parameters ('ln' / 'norm')
    freeze_attn=False,   # train attention parameters ('attn')
    freeze_ff=True,      # keep feed-forward ('mlp') parameters frozen
    freeze_other=False,  # train everything else
)

train(model, args, train_dataloader, device)