In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.io import read_image
from torchvision.transforms import ToTensor,Compose
from accelerate.utils import LoggerType


import os
import pandas as pd
from torchvision.io import read_image
from PIL import Image

class CustomImageDataset(Dataset):
    """Unlabeled image dataset made of every file found under a directory tree.

    Files are discovered recursively with ``os.walk`` the first time the
    dataset is sized or indexed; the resulting path list is cached in
    ``self.all_files`` and reused afterwards. Each item is opened with PIL,
    converted to RGB and passed through ``transform``. No extension filtering
    is performed, so every file under ``img_dir`` must be an image PIL can open.

    Args:
        img_dir: Root directory to search for images.
        transform: Callable applied to each RGB ``PIL.Image``. Defaults to
            ``Compose([ToTensor()])``, matching the original behavior.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        # Build the transform once instead of on every __getitem__ call.
        self.transform = transform if transform is not None else Compose([ToTensor()])
        # None means "not scanned yet". An empty list is a valid (cached) result,
        # so it must not trigger a rescan.
        self.all_files = None

    def _ensure_files(self):
        """Scan ``img_dir`` once and return the cached, sorted list of file paths."""
        if self.all_files is None:
            # Sort both the directories and the files inside each directory so
            # that index -> file mapping is deterministic across runs.
            self.all_files = [
                os.path.join(root, name)
                for root, _dirs, files in sorted(os.walk(self.img_dir))
                for name in sorted(files)
            ]
            print(str(len(self.all_files)) + " images found.")
        return self.all_files

    def __len__(self):
        return len(self._ensure_files())

    def __getitem__(self, idx):
        path = self._ensure_files()[idx]
        # The context manager guarantees the underlying file handle is closed;
        # convert() materializes the pixel data before the file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        return self.transform(image)

from torch.utils.data import DataLoader

# Root directory that CustomImageDataset scans recursively for images.
IMG_DIR = '/media/gamal/Passport/Datasets/VoxCeleb2Test/Voxceleb2TestFaces'
BATCH_SIZE = 64

training_data = CustomImageDataset(IMG_DIR)

# Both loaders share the same settings.
# NOTE(review): the "test" loader wraps the very same dataset as the training
# loader, so validation batches are drawn from training images — confirm this
# is intended (e.g. only used for reconstruction previews).
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True)
train_dataloader = DataLoader(training_data, **loader_kwargs)
test_dataloader = DataLoader(training_data, **loader_kwargs)

In [None]:
from accelerate import Accelerator

# Base directory handed to accelerate as project_dir; the "tensorboard"
# tracker is selected via log_with.
VAE_LOG_DIR = '/media/gamal/Passport/muse/vqganvae_log'

accelerator = Accelerator(
    project_dir=VAE_LOG_DIR,
    log_with="tensorboard",
)

In [None]:
import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

import os
import configparser

# Loading configurations
configParser = configparser.RawConfigParser()
configFilePath = r'configuration.txt'
configParser.read(configFilePath)
# datasetPathVideo is not used by any of the visible cells. Fall back to None
# instead of raising NoSectionError / NoOptionError when configuration.txt (or
# its [COMMON] section) is missing, so an unused setting cannot abort training.
datasetPathVideo = configParser.get('COMMON', 'datasetPathVideo', fallback=None)

# Checkpoint to resume from and directory where the trainer writes results.
VAE_CHECKPOINT = '/media/gamal/Passport/muse/vqganvae/vae.12145000.pt'
VAE_RESULTS_DIR = '/media/gamal/Passport/muse/vqganvae'

# NOTE(review): these architecture arguments presumably have to match the ones
# VAE_CHECKPOINT was trained with for vae.load() to succeed — keep them in sync.
vae = VQGanVAE(
    dim = 32,
    vq_codebook_dim = 8192,
    vq_codebook_size = 8192,
    channels=3,
)

vae.load(VAE_CHECKPOINT)

# Train on the dataloaders built in the first cell, using the shared accelerator.
# NOTE(review): current_step=1 although the checkpoint name suggests step
# 12145000 — confirm whether the step counter should resume from the checkpoint.
# NOTE(review): save_model_every (100000) is larger than num_train_steps (50000),
# so a periodic checkpoint may never be written — confirm the trainer saves on exit.
trainer = VQGanVAETrainer(
    current_step=1,
    num_epochs=50000,
    vae = vae,
    dataloader=train_dataloader,
    valid_dataloader=test_dataloader,
    gradient_accumulation_steps = 1,
    num_train_steps = 50000,
    results_dir = VAE_RESULTS_DIR,
    save_results_every = 100,
    save_model_every = 100000,
    only_save_last_checkpoint=True,
    accelerator=accelerator,
).cuda()

trainer.train()

