In [23]:
import os
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch 
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from transformers import ViTMAEForPreTraining, AutoImageProcessor
from transformers import TrainingArguments, Trainer
from transformers import ViTFeatureExtractor

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image Processor

In [25]:
image_processor_checkpoint = r"facebook/vit-mae-base"
image_processor = AutoImageProcessor.from_pretrained(image_processor_checkpoint)

In [26]:
feature_extractor = ViTFeatureExtractor("facebook/vit-mae-base")



# Dataset

In [45]:
class BubblesDataset(Dataset):
    def __init__(self, images_dir, image_processor, feature_extractor, transform=None):
        self.images_dir = images_dir
        self.image_processor = image_processor
        self.feature_extractor = feature_extractor
        self.transform = transform
        
        self.image_list = os.listdir(self.images_dir)
        
    def __len__(self):
        return len(self.image_list)
    
    def __getitem__(self, item):
        image = Image.open(os.path.join(self.images_dir, self.image_list[0]))
        to_tensor = transforms.ToTensor()
        image = to_tensor(image)
                
        if self.transform:
            image = self.transform(image)
            
        inputs = image_processor(images=image, return_tensors="pt")
        
        return inputs, image

In [46]:
transform = transforms.Compose(
    transforms.Resize((256, 256))
)

In [68]:
def collate_fn(examples):
    pixel_values = torch.stack([example.pixel_values for example in examples])
    return {"pixel_values": pixel_values}


In [69]:
train_images_dir = r"C:\Internship\ITMO ML\data\Frames\Bubbles every frame\F1_1_1_1.ts-frames"
val_images_dir = r"C:\Internship\ITMO ML\data\Frames\Bubbles every frame\F1_1_1_2.ts-frames"

train_dataset = BubblesDataset(train_images_dir, image_processor, feature_extractor)
val_dataset = BubblesDataset(val_images_dir, image_processor, feature_extractor)

# Model

In [70]:

model_checkpoint = "facebook/vit-mae-base"
model = ViTMAEForPreTraining.from_pretrained(model_checkpoint)

In [71]:
def compute_metrics():
    pass

In [72]:
training_args = TrainingArguments(
    output_dir = r'C:\Internship\ITMO ML\CTCI\notebooks\inpainting\output', 
    num_train_epochs = 3, 
    per_device_train_batch_size = 8, 
    per_device_eval_batch_size = 8, 
    # weight_decay =0.01, 
    logging_dir = r'C:\Internship\ITMO ML\CTCI\notebooks\inpainting\logs', 
    load_best_model_at_end = True, 
    learning_rate = 1e-5, 
    evaluation_strategy ='epoch', 
    logging_strategy = 'epoch', 
    save_strategy = 'epoch', 
    save_total_limit = 3,
    seed=239
)

In [73]:
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
    data_collator=collate_fn
    # compute_metrics = compute_metrics
)

In [74]:
trainer.train()

TypeError: expected Tensor as element 0 in argument 0, but got BatchFeature