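# Image Captioning Project

Fine-tune Microsoft's GIT model (`microsoft/git-base`) to caption the iconic product images of the [GroceryStoreDataset](https://github.com/marcusklasson/GroceryStoreDataset), using each product's text description as its target caption.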


In [41]:
import requests
import os
import zipfile
import shutil
import json
from datasets import load_dataset

# Helper class: download the dataset archive from a URL, pair each iconic image with its text description, and build a Hugging Face image-caption dataset from those pairs.

class CreateImageDataset():
    """Download a GroceryStoreDataset-style archive and turn it into an image-caption dataset.

    Workflow:
        1. ``downloader()`` fetches a ZIP archive from ``url`` and extracts it
           into ``destination_folder``.
        2. ``create_dictionary()`` finds every ``<name>_Description.txt`` that
           has a sibling ``<name>_Iconic.jpg``, copies the image into
           ``images_folder`` and collects ``{"file_name", "text"}`` records.
        3. ``create_dataset()`` writes those records to
           ``<images_folder>/metadata.jsonl`` and loads the folder with the
           Hugging Face ``imagefolder`` builder.

    Args:
        url: URL of the ZIP archive to download.
        destination_folder: Directory the archive is saved to and extracted into.
        images_folder: Directory that receives the copied images and metadata.jsonl.
    """

    # File-name conventions used by the dataset to pair descriptions with images.
    DESCRIPTION_SUFFIX = "_Description.txt"
    IMAGE_SUFFIX = "_Iconic.jpg"

    def __init__(self, url, destination_folder, images_folder):
        self.url = url
        self.destination_folder = destination_folder
        self.images_folder = images_folder

    def downloader(self, timeout=60):
        """Download the archive at ``self.url`` and extract it into ``self.destination_folder``.

        Failures are reported with ``print`` rather than raised, as in the
        original notebook.

        Args:
            timeout: Seconds to wait on the server; without a timeout
                ``requests`` can block indefinitely on a stalled connection.
        """
        destination = self.destination_folder
        os.makedirs(destination, exist_ok=True)
        filename = os.path.join(destination, self.url.split('/')[-1])

        try:
            # Stream to disk instead of holding the whole archive in memory.
            with requests.get(self.url, allow_redirects=True, stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"Failed to download archive. Status code: {response.status_code}")
                    return
                with open(filename, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
        except requests.RequestException as e:
            print(f"Failed to download archive: {e}")
            return

        # Extract directly within the destination folder (no temporary directory).
        try:
            with zipfile.ZipFile(filename, 'r') as zip_ref:
                zip_ref.extractall(destination)
            print("Folder downloaded and extracted to:", destination)
        except zipfile.BadZipFile:
            print(f"Error: Downloaded file {filename} is not a valid ZIP archive.")
        except Exception as e:
            print(f"Failed to unzip archive: {e}")

    def create_dictionary(self):
        """Pair each ``<name>_Description.txt`` with its ``<name>_Iconic.jpg`` image.

        Walks ``self.destination_folder`` in sorted order, so the result is
        reproducible across runs. Each matching image is copied into
        ``self.images_folder`` unless a file of that name is already there.
        Descriptions without a matching image are skipped with a warning.
        Such an entry would point ``metadata.jsonl`` at a file that does not
        exist.

        Returns:
            list[dict]: ``{"file_name": <image basename>, "text": <description>}``
            records, one per image found. Description text is stripped of
            leading and trailing whitespace.
        """
        captions = []
        for root, dirs, files in os.walk(self.destination_folder):
            dirs.sort()  # in-place sort makes os.walk descend in a deterministic order
            for filename in sorted(files):
                if not filename.endswith(self.DESCRIPTION_SUFFIX):
                    continue
                item_name = filename[:-len(self.DESCRIPTION_SUFFIX)]
                image_file = os.path.join(root, item_name + self.IMAGE_SUFFIX)
                if not os.path.isfile(image_file):
                    print(f"Warning: no image found for {os.path.join(root, filename)}; skipping.")
                    continue

                image_name = os.path.basename(image_file)
                if not os.path.exists(os.path.join(self.images_folder, image_name)):
                    os.makedirs(self.images_folder, exist_ok=True)
                    shutil.copy(image_file, self.images_folder)

                # Explicit encoding: the platform default is not UTF-8 everywhere.
                with open(os.path.join(root, filename), 'r', encoding='utf-8') as f:
                    description_text = f.read().strip()

                captions.append({
                    "file_name": image_name,
                    "text": description_text,
                })

        return captions

    def create_dataset(self, captions):
        """Write ``metadata.jsonl`` and load ``self.images_folder`` as an ``imagefolder`` dataset.

        Args:
            captions: Records as returned by ``create_dictionary()``.

        Returns:
            The ``train`` split returned by ``load_dataset``.
        """
        self._write_metadata(captions)
        dataset = load_dataset("imagefolder", data_dir=self.images_folder, split="train")
        print("Dataset created successfully")
        print("Dataset info: ", dataset)
        return dataset

    def _write_metadata(self, captions):
        """Write one JSON object per line to ``<images_folder>/metadata.jsonl``; return that path."""
        os.makedirs(self.images_folder, exist_ok=True)
        metadata_path = os.path.join(self.images_folder, "metadata.jsonl")
        with open(metadata_path, 'w', encoding='utf-8') as f:
            for item in captions:
                f.write(json.dumps(item) + "\n")
        return metadata_path


In [42]:
repo_url = "https://github.com/marcusklasson/GroceryStoreDataset/archive/refs/heads/master.zip"
destination_folder = "./data"
images_folder = "./grocery_store_images"

# Fetch and unpack the archive, then collect (image, description) records.
create_image_dataset = CreateImageDataset(
    url=repo_url,
    destination_folder=destination_folder,
    images_folder=images_folder,
)
create_image_dataset.downloader()
captions = create_image_dataset.create_dictionary()

Folder downloaded and extracted to: ./data


In [44]:
# create_dataset already prints a summary; ending the cell with the bare
# expression lets Jupyter render the Dataset once via its rich display
# instead of an extra print().
dataset = create_image_dataset.create_dataset(captions)
dataset

Generating train split: 81 examples [00:00, 7998.74 examples/s]

Dataset created successfully
Dataset info:  DatasetInfo(description='', citation='', homepage='', license='', features={'image': Image(decode=True, id=None), 'text': Value(dtype='string', id=None)}, post_processed=None, supervised_keys=None, task_templates=None, builder_name='imagefolder', dataset_name='imagefolder', config_name='default', version=0.0.0, splits={'train': SplitInfo(name='train', num_bytes=24720, num_examples=81, shard_lengths=None, dataset_name='imagefolder')}, download_checksums={'/home/ziyangfu/Code/ImageCaptioning/grocery_store_images/Alpro-Blueberry-Soyghurt_Iconic.jpg': {'num_bytes': 8512, 'checksum': None}, '/home/ziyangfu/Code/ImageCaptioning/grocery_store_images/Alpro-Fresh-Soy-Milk_Iconic.jpg': {'num_bytes': 10260, 'checksum': None}, '/home/ziyangfu/Code/ImageCaptioning/grocery_store_images/Alpro-Shelf-Soy-Milk_Iconic.jpg': {'num_bytes': 11247, 'checksum': None}, '/home/ziyangfu/Code/ImageCaptioning/grocery_store_images/Alpro-Vanilla-Soyghurt_Iconic.jpg': {'num




In [45]:
from torch.utils.data import Dataset

class ImageCaptioningDataset(Dataset):
    """Torch ``Dataset`` that turns (image, text) records into model-ready tensors.

    Args:
        dataset: Indexable collection whose items have ``"image"`` and
            ``"text"`` keys (here, the Hugging Face dataset built earlier).
        processor: Callable accepting ``images=``, ``text=``, ``padding=`` and
            ``return_tensors="pt"``. Presumably each returned tensor has a
            leading batch dimension of size 1; confirm for other processors.
        padding: Padding strategy forwarded to the processor.
        max_length: Optional maximum token length. When given, it is forwarded
            with ``truncation=True`` so every item has the same length.
            ``None`` keeps the processor's own default.
    """

    def __init__(self, dataset, processor, padding="max_length", max_length=None):
        self.dataset = dataset
        self.processor = processor
        self.padding = padding
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        processor_kwargs = {"padding": self.padding, "return_tensors": "pt"}
        if self.max_length is not None:
            processor_kwargs["max_length"] = self.max_length
            processor_kwargs["truncation"] = True

        encoding = self.processor(images=item["image"], text=item["text"], **processor_kwargs)

        # Drop only the leading batch dimension. A bare squeeze() would also
        # collapse any other size-1 axis and change the tensor's rank.
        return {k: v.squeeze(0) for k, v in encoding.items()}

In [48]:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path="microsoft/git-base")

# Wrap the Hugging Face dataset so each item comes back as processor tensors.
train_dataset = ImageCaptioningDataset(dataset=dataset, processor=processor)

In [64]:
from torch.utils.data import DataLoader
import torch

# Seed the shuffle order so Restart & Run All reproduces the same batches.
DATALOADER_SEED = 42
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=3,
    generator=torch.Generator().manual_seed(DATALOADER_SEED),
)

In [60]:
import torch


class Transformer_model:
    """Minimal fine-tuning and captioning wrapper around an image-captioning causal LM.

    Args:
        model: A ``torch.nn.Module`` exposing ``.device``. Its forward must
            accept ``input_ids``, ``pixel_values``, ``attention_mask`` and
            ``labels`` and return an object with a ``.loss`` attribute. In this
            notebook it is ``AutoModelForCausalLM`` loaded from ``microsoft/git-base``.
        device: Device the model is moved to. The default ``"cpu"`` keeps the
            original behaviour; pass ``"cuda"`` to train on a GPU.
        lr: AdamW learning rate.
    """

    def __init__(self, model, device="cpu", lr=5e-5):
        self.model = model.to(device)
        self.device = self.model.device
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)

    def model_train(self, epochs, train_dataloader):
        """Fine-tune the model for ``epochs`` passes over ``train_dataloader``.

        Each batch is a mapping with ``input_ids`` and ``pixel_values`` and,
        optionally, ``attention_mask``. The captions are used as their own
        labels. Positions where ``attention_mask`` is 0 (padding) are set to
        -100 so the loss ignores them. Prints the epoch index and the mean
        batch loss for each epoch.

        Returns:
            The trained model (the same object as ``self.model``).
        """
        # Hugging Face from_pretrained() returns models in eval mode; without
        # this call dropout stays disabled while fine-tuning.
        self.model.train()
        for epoch in range(epochs):
            print("Epoch:", epoch)
            total_loss, num_batches = 0.0, 0
            for batch in train_dataloader:
                input_ids = batch["input_ids"].to(self.device)
                pixel_values = batch["pixel_values"].to(self.device)
                attention_mask = batch.get("attention_mask")

                labels = input_ids
                if attention_mask is not None:
                    attention_mask = attention_mask.to(self.device)
                    # -100 is the default ignore_index of torch's CrossEntropyLoss.
                    labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = self.model(
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    attention_mask=attention_mask,
                    labels=labels,
                )

                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                total_loss += loss.item()
                num_batches += 1

            if num_batches:
                print(f"  mean loss: {total_loss / num_batches:.4f}")

        return self.model

    @staticmethod
    def model_inference(image, processor, model, max_length=None):
        """Generate, print and return a caption for a single ``image``.

        Args:
            image: Image accepted by ``processor(images=...)``.
            processor: Object providing ``__call__`` and ``batch_decode``.
            model: Object providing ``eval()``, ``generate()`` and ``.device``.
            max_length: Optional cap forwarded to ``model.generate``. ``None``
                keeps the model's generation default.

        Returns:
            str: The decoded caption. The original only printed it; returning
            it as well is backward-compatible.
        """
        # After model_train() the model is in training mode; switch to eval so
        # dropout does not randomise the generated caption.
        model.eval()
        inputs = processor(images=image, return_tensors="pt").to(model.device)
        generate_kwargs = {"pixel_values": inputs.pixel_values}
        if max_length is not None:
            generate_kwargs["max_length"] = max_length
        generated_ids = model.generate(**generate_kwargs)
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        print(generated_caption)
        return generated_caption


In [61]:
from transformers import AutoModelForCausalLM
# Pretrained GIT base checkpoint; keep it in sync with the processor checkpoint.
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="microsoft/git-base")

In [63]:
# Number of full passes over train_dataloader (named instead of a bare magic number).
NUM_EPOCHS = 20

Transformer = Transformer_model(model)
model = Transformer.model_train(NUM_EPOCHS, train_dataloader)

Unused or unrecognized kwargs: padding.


Epoch: 0


KeyboardInterrupt: 