In [None]:
import torch
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login

In [None]:
# Hub authentication is only needed for the Hub-hosted copy of the dataset
# (commented out below). Never paste a token into the notebook: read it from
# the environment instead (needs `import os`). The literal token that used to
# sit on this line must be treated as leaked and revoked.
#login(token = os.environ["HF_TOKEN"])

#ploty_dataset_train = load_dataset("martinsinnona/ploty", split = "train")
#ploty_dataset_test = load_dataset("martinsinnona/ploty", split = "test")

# Directory passed to the `imagefolder` loader for both splits.
DATA_DIR = "dataset"

ploty_dataset_train = load_dataset("imagefolder", data_dir = DATA_DIR, split = "train")
ploty_dataset_test = load_dataset("imagefolder", data_dir = DATA_DIR, split = "test")

In [None]:
# Print the train and test dataset objects as a quick sanity check.
print(ploty_dataset_train, ploty_dataset_test)

In [None]:
from torch.utils.data import Dataset, DataLoader

# Value passed as `max_patches` to the processor for every training image.
MAX_PATCHES = 1024

class ImageCaptioningDataset(Dataset):
    """Map-style dataset pairing processor-encoded images with their raw caption.

    Every item of the wrapped dataset must expose an ``"image"`` and a
    ``"text"`` entry.
    """

    def __init__(self, dataset, processor):
        """Keep references to the indexable `dataset` and the encoding `processor`."""
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode item `idx`.

        Returns:
            dict: every tensor produced by the processor with its leading
            axis squeezed away, plus the untouched caption under ``"text"``.
        """
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text = "", return_tensors="pt", add_special_tokens=True, max_patches=MAX_PATCHES)

        # Remove only the leading axis (assumed to be the size-1 batch axis the
        # processor adds for a single image). A bare squeeze() would also
        # collapse any other axis that happens to have size 1.
        encoding = {k:v.squeeze(0) for k,v in encoding.items()}
        encoding["text"] = item["text"]

        return encoding

In [None]:
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

# The processor and the model weights are loaded from different checkpoints:
# the processor from "ybelkada/pix2struct-base", the weights from "google/matcha-base".
# NOTE(review): presumably intentional — confirm the two are meant to be paired.
processor = AutoProcessor.from_pretrained("ybelkada/pix2struct-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-base")

In [None]:
def collator(batch):
    """Collate dataset samples into a single training batch.

    Args:
        batch: list of dicts, each holding ``"text"`` (caption string),
            ``"flattened_patches"`` and ``"attention_mask"`` (tensors of equal
            shape across samples).

    Returns:
        dict with the stacked ``"flattened_patches"`` and ``"attention_mask"``
        and with ``"labels"``: caption token ids padded/truncated to 200
        tokens, where padding positions are set to -100.
    """
    texts = [item["text"] for item in batch]

    # truncation=True keeps every row at exactly max_length, so one over-long
    # caption cannot make the rows ragged and break tensor creation.
    text_inputs = processor(text=texts, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=True, max_length=200)

    labels = text_inputs.input_ids

    # -100 is the ignore index of the cross-entropy loss: without this mask the
    # loss is dominated by the padding that fills most of the 200 positions.
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None:
        labels[labels == pad_token_id] = -100

    return {
        "flattened_patches": torch.stack([item["flattened_patches"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "labels": labels,
    }

In [None]:
# Wrap the raw training split so samples are encoded on access, then serve
# them in shuffled batches of two through the custom collate function.
train_dataset = ImageCaptioningDataset(dataset=ploty_dataset_train, processor=processor)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collator,
)

In [None]:
def eval_model(dataset, max_patches=512, max_length=200):
    """Return the exact-match accuracy of the global `model` on `dataset`.

    Uses the module-level `model`, `processor` and `device`. The model is put
    in eval mode for the duration of the call and its previous train/eval mode
    is restored afterwards, so calling this from inside a training loop does
    not silently leave the model in eval mode.

    Args:
        dataset: iterable of items exposing ``"image"`` and ``"text"``.
        max_patches: `max_patches` value given to the processor per image.
            NOTE(review): training encodes images with MAX_PATCHES = 1024
            while this defaults to 512 — confirm the mismatch is intended.
        max_length: `max_length` value given to `model.generate`.

    Returns:
        Fraction of items whose decoded generation equals ``item["text"]``
        exactly; NaN when `dataset` is empty.
    """
    was_training = model.training
    model.eval()  # loop-invariant: set once instead of once per item

    matches = []

    try:
        for i, data in enumerate(dataset):

            if i % 10 == 0: print(i)

            inputs = processor(images=data["image"], return_tensors="pt", max_patches=max_patches).to(device)

            generated_ids = model.generate(flattened_patches=inputs.flattened_patches, attention_mask=inputs.attention_mask, max_length=max_length)
            generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            matches.append(data["text"] == generated_caption)
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    if not matches:
        # Same value np.mean([]) yields, without the RuntimeWarning.
        return np.nan

    return np.mean(matches)

In [None]:
# Seed torch's global random number generator with a fixed value
# (presumably so that runs are repeatable).
seed = 14895215085708117999
torch.manual_seed(seed)

In [None]:
EPOCHS = 100
eval_step = 50      # run the test-set evaluation every `eval_step` epochs
preview_step = 50   # decode per-batch predictions on every `preview_step`-th epoch

optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

losses = []            # loss of the last batch of each epoch
accuracies_train = []  # placeholder zeros (train evaluation is disabled below)
accuracies_test = []   # exact-match test accuracy, one entry per evaluation

# range(EPOCHS + 1) lets the epoch index reach EPOCHS itself; with EPOCHS = 100
# and eval_step = 50 the evaluation therefore runs at epochs 0, 50 and 100.
for epoch in range(EPOCHS + 1):

    print("Epoch:", epoch)
    loss = None  # stays None only when the dataloader yields no batch

    for idx, batch in enumerate(train_dataloader):

        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(flattened_patches = flattened_patches,
                    attention_mask = attention_mask,
                    labels = labels)

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Qualitative check: decode what the model generates for this batch.
        if (epoch + 1) % preview_step == 0:

            model.eval()

            predictions = model.generate(flattened_patches = flattened_patches, attention_mask = attention_mask)
            print("Predictions:", processor.batch_decode(predictions, skip_special_tokens = True))

            model.train()

    if epoch % eval_step == 0:

        #accuracy_train = eval_model(ploty_dataset_train)
        accuracy_train = 0
        accuracy_test = eval_model(ploty_dataset_test)

        # eval_model() puts the model in eval mode; switch back so that
        # train-only behaviour (e.g. dropout) applies to the next epochs.
        model.train()

        accuracies_train.append(accuracy_train)
        accuracies_test.append(accuracy_test)

        print("\n\nAccuracy (train): " + str(accuracy_train) + "\n")
        print("Accuracy (test): " + str(accuracy_test) + "\n")
        print(accuracies_train)
        print(accuracies_test)
        print("\n")

    # Record the last batch loss of the epoch; skipped when no batch was seen
    # (the original raised NameError on an empty dataloader).
    if loss is not None:
        losses.append(loss.item())

In [None]:
# Upload the model to the Hugging Face Hub under the repository name "modelD".
# NOTE(review): the login() call earlier in the notebook is commented out, so
# this presumably relies on credentials already cached on the machine — confirm.
model.push_to_hub("modelD")

In [None]:
# Accuracies (one point per evaluation) and the per-epoch loss on one y-axis.
# The y-range is fixed, so loss values above 1.1 fall outside the visible area.
plt.ylim(-0.01,1.1)
plt.yticks(np.linspace(0,1,21))
plt.tick_params(axis='y', labelsize = 8)
plt.grid(axis = 'y', linewidth = 0.5)
plt.xlabel("epochs")
plt.ylabel("accuracies")

# Evaluations happen every `eval_step` epochs, hence the strided x positions.
plt.plot(np.arange(0, len(accuracies_train) * eval_step, eval_step), accuracies_train, label = "accuracy train")
plt.plot(losses, label = "loss")
plt.plot(np.arange(0, len(accuracies_test) * eval_step, eval_step), accuracies_test, label = "accuracy test")

# Reference line at accuracy 1. axhline's xmin/xmax are fractions of the axes
# width in [0, 1], not data coordinates, so the former xmax = 800 was out of
# range; the defaults (0 and 1) already span the full width.
plt.axhline(y = 1, color = "gray", linestyle = "dashed", alpha = 0.4)

# Trailing semicolon keeps the Legend repr out of the cell output.
plt.legend();

In [None]:
# Display the recorded train accuracies (the training loop stores a
# placeholder 0 per evaluation because the train evaluation is commented out).
accuracies_train

In [None]:
# Final pass over the test split: collect [reference, generated] pairs and
# report the exact-match accuracy.
model.eval()  # loop-invariant, so set once before iterating

results = []

for data in ploty_dataset_test:

    image = data["image"]

    inputs = processor(images=image, return_tensors="pt", max_patches=512).to(device)

    flattened_patches = inputs.flattened_patches
    attention_mask = inputs.attention_mask

    # Labels are padded to 200 tokens in the collator and eval_model() also
    # generates up to 200; the former cap of 50 here could cut longer captions
    # short and count them as mismatches. Generation still stops at
    # end-of-sequence, so short captions are unaffected.
    generated_ids = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=200)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    results.append([data["text"], generated_caption])

accuracy = np.mean([res[0] == res[1] for res in results])

print("Accuracy: " + str(accuracy))
results

In [None]:
import time

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump when the system
# clock is adjusted, which may even produce a negative duration.
start_time = time.perf_counter()

# Workload being timed: 100000 trivial assignments.
for i in range(100000):
    a = 1

end_time = time.perf_counter()
elapsed_time = end_time - start_time  # seconds, as a float