In [40]:
import torch
from datasets import load_dataset
import matplotlib.pyplot as plt

In [41]:
# Load the "train" split of the martinsinnona/ploty dataset.
dataset = load_dataset(
    "martinsinnona/ploty",
    split="train",
)
# The "test" split is not loaded here; uncomment to load it the same way.
#dataset_test = load_dataset("martinsinnona/ploty", split = "test")

In [42]:
# Display the dataset summary (feature names and number of rows).
dataset

Dataset({
    features: ['image', 'text'],
    num_rows: 45
})

In [43]:
from torch.utils.data import Dataset, DataLoader

# Default value passed as `max_patches` to the processor when encoding an image.
MAX_PATCHES = 1024

class ImageCaptioningDataset(Dataset):
    """Map-style dataset yielding processor-encoded images plus their target text.

    Args:
        dataset: Indexable collection whose items are mappings with an
            "image" entry (handed to the processor) and a "text" entry
            (returned unchanged).
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt",
            add_special_tokens=True, max_patches=...)``. It is expected to
            return a mapping of tensors carrying a leading batch axis.
        max_patches: Patch budget forwarded to the processor. Defaults to
            the module-level ``MAX_PATCHES``.
    """

    def __init__(self, dataset, processor, max_patches=MAX_PATCHES):
        self.dataset = dataset
        self.processor = processor
        self.max_patches = max_patches

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode the image at ``idx`` and attach its raw target text.

        Returns:
            dict: every tensor produced by the processor with its leading
            batch axis removed, plus a "text" key holding the item's text.
        """
        item = self.dataset[idx]
        encoding = self.processor(
            images=item["image"],
            return_tensors="pt",
            add_special_tokens=True,
            max_patches=self.max_patches,
        )

        # Drop only the leading batch axis. A bare squeeze() would also
        # collapse any other axis that happens to have size 1.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding["text"] = item["text"]

        return encoding

In [44]:
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

# Load the processor and the model from the same pretrained checkpoint.
processor = AutoProcessor.from_pretrained(
    "ybelkada/pix2struct-base",
)
model = Pix2StructForConditionalGeneration.from_pretrained(
    "ybelkada/pix2struct-base",
)

In [45]:
def collator(batch):
    """Merge a list of dataset items into a single training batch.

    Args:
        batch: List of items, each a mapping with "text",
            "flattened_patches" and "attention_mask" entries.

    Returns:
        dict with "flattened_patches" and "attention_mask" (the per-item
        tensors stacked along a new first axis) and "labels" (the token ids
        of the texts, padded to a fixed length of 40).
    """
    # Tokenize all target strings in one call via the module-level processor.
    captions = [sample["text"] for sample in batch]
    tokenized = processor(
        text=captions,
        padding="max_length",
        return_tensors="pt",
        add_special_tokens=True,
        max_length=40,
    )

    patches = [sample["flattened_patches"] for sample in batch]
    masks = [sample["attention_mask"] for sample in batch]

    return {
        "flattened_patches": torch.stack(patches),
        "attention_mask": torch.stack(masks),
        "labels": tokenized.input_ids,
    }

In [46]:
# Wrap the raw dataset so items are processor-encoded on access, then serve
# them in shuffled batches of two assembled by `collator`.
train_dataset = ImageCaptioningDataset(dataset=dataset, processor=processor)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collator,
)

In [47]:
# Fixed RNG seed; handed to torch.manual_seed() in the training cell.
seed = 14895215085708117999

In [None]:
# Fine-tune the model, logging the loss of every batch, sample generations
# every 20th epoch, and exact-match accuracy every 100th epoch.
EPOCHS = 5000

optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Fixed seed for repeatable runs (value presumably captured from torch.seed()).
torch.manual_seed(seed)

model.train()

losses = []      # loss of the last batch of each epoch
accuracies = []  # exact-match accuracy, one entry per 100 epochs

for epoch in range(EPOCHS):

    print("Epoch:", epoch)
    for batch in train_dataloader:

        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(flattened_patches = flattened_patches,
                        attention_mask = attention_mask,
                        labels = labels)

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Every 20th epoch, also print what the model generates for each
        # training batch, then switch back to training mode.
        if (epoch + 1) % 20 == 0:

            model.eval()

            predictions = model.generate(flattened_patches = flattened_patches, attention_mask = attention_mask)
            print("Predictions:", processor.batch_decode(predictions, skip_special_tokens = True))

            model.train()

    # Every 100th epoch, measure exact-match accuracy over the whole dataset.
    if (epoch + 1) % 100 == 0:

        results = []

        # Switching to eval mode once is enough for the whole pass.
        model.eval()

        for data in dataset:

            # NOTE(review): evaluation encodes with max_patches=512 while the
            # training items use MAX_PATCHES — confirm this is intentional.
            inputs = processor(images=data["image"], return_tensors="pt", max_patches=512).to(device)

            generated_ids = model.generate(flattened_patches=inputs.flattened_patches, attention_mask=inputs.attention_mask, max_length=50)
            generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            results.append([data["text"], generated_caption])

        # Restore training mode; otherwise the following epochs would keep
        # running in eval mode until the next 20-epoch probe.
        model.train()

        accuracy = sum(target == generated for target, generated in results) / len(dataset)
        accuracies.append(accuracy)

        # str() is required here: concatenating the list directly raises TypeError.
        print("\n\nAccuracy: " + str(accuracy) + "\n" + str(accuracies) + "\n\n")

    # Keep the last batch loss of the epoch as a plain Python float.
    losses.append(loss.item())

Epoch: 0
Loss: 33.61376953125
Loss: 31.73067283630371
Loss: 29.96343421936035
Loss: 28.87127113342285
Loss: 27.68637466430664
Loss: 26.194772720336914
Loss: 29.694087982177734
Loss: 27.728179931640625
Loss: 25.795007705688477
Loss: 26.45102310180664
Loss: 25.57311248779297
Loss: 24.646150588989258
Loss: 26.316425323486328
Loss: 24.27987289428711
Loss: 24.230772018432617
Loss: 27.45828628540039
Loss: 22.946203231811523
Loss: 24.926530838012695
Loss: 23.482873916625977
Loss: 24.991756439208984
Loss: 21.299518585205078
Loss: 23.10517692565918
Loss: 21.44572639465332
Epoch: 1
Loss: 20.919490814208984
Loss: 23.240283966064453
Loss: 21.465736389160156
Loss: 22.05978012084961
Loss: 21.141347885131836
Loss: 20.253454208374023
Loss: 21.645145416259766
Loss: 22.65595817565918
Loss: 20.406471252441406
Loss: 21.041568756103516
Loss: 19.837553024291992
Loss: 19.621917724609375
Loss: 19.628734588623047
Loss: 22.18081283569336
Loss: 21.087142944335938
Loss: 19.81709098815918
Loss: 20.01675033569336
L



Predictions: ['<mark> point </mark><x> num1 </x><y> num2 </y>', '<mark> point </mark><x> num1 </x><y> num2 </y>']
Loss: 3.229623317718506
Predictions: ['<mark> point </mark><x> num1 </x><y> num1 </y>', '<mark> point </mark><x> num1 </x><y> num1 </y>']
Loss: 3.0967984199523926
Predictions: ['<mark> point </mark><x> num2 </x><y> num2 </y>', '<mark> point </mark><x> num9 </x><y> num9 </y>']
Loss: 3.06111216545105
Predictions: ['<mark> point </mark><x> num1 </x><y> num2 </y>', '<mark> point </mark><x> num1 </x><y> num1 </y>']
Loss: 2.972745895385742
Predictions: ['<mark> point </mark><x> num1 </x><y> num2 </y>', '<mark> point </mark><x> num1 </x><y> num2 </y>']
Loss: 2.840327501296997
Predictions: ['<mark> point </mark><x> num2 </x><y> num2 </y>', '<mark> point </mark><x> num0 </x><y> num0 </y>']
Loss: 3.062047004699707
Predictions: ['<mark> point </mark><x> num8 </x><y> num8 </y>', '<mark> point </mark><x> num3 </x><y> num3 </y>']
Loss: 3.147752285003662
Predictions: ['<mark> point </mark

In [None]:
# Plot the per-epoch loss together with the periodically measured accuracy.
# The y-range is clamped so the accuracy curve (0..1) stays readable.
plt.ylim(-0.01, 1.1)

# Accuracy is appended at the end of every 100th epoch, i.e. at epoch indices
# 99, 199, ... Deriving x from len(accuracies) keeps x and y the same length
# even if training stopped part-way, and needs no numpy import.
accuracy_epochs = [100 * (k + 1) - 1 for k in range(len(accuracies))]

plt.plot(accuracy_epochs, accuracies, label = "accuracy")
plt.plot(losses, label = "loss")

# Reference line at accuracy 1. axhline's xmin/xmax are axes fractions in
# [0, 1], so the defaults already span the full width.
plt.axhline(y = 1, color = "gray", linestyle = "dashed")

plt.legend()

In [None]:
# Final evaluation: generate a caption for every dataset image and report the
# fraction that matches the reference text exactly.
results = []

# Eval mode only needs to be set once, not once per sample.
model.eval()

for data in dataset:

    inputs = processor(images=data["image"], return_tensors="pt", max_patches=512).to(device)

    generated_ids = model.generate(flattened_patches=inputs.flattened_patches, attention_mask=inputs.attention_mask, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Each row is [reference text, generated text].
    results.append([data["text"], generated_caption])

# Guard against an empty dataset instead of raising ZeroDivisionError.
accuracy = sum(target == generated for target, generated in results) / len(results) if results else 0.0

print("Accuracy: " + str(accuracy))
results