In [None]:
pip install pandas
!pip install flash_attn timm transformers
!pip install datasets

In [None]:
import torch
import os
import pandas as pd
import gc

# Free memory left over from earlier runs in this kernel. Collect unreachable Python
# objects first so any CUDA tensors they held are released, then return the now-unused
# cached blocks to the driver (empty_cache cannot release blocks still referenced).
gc.collect()
torch.cuda.empty_cache()

In [None]:
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

# Hub checkpoint and revision used for both the model and its processor.
MODEL_ID = "microsoft/Florence-2-base-ft"
MODEL_REVISION = 'refs/pr/6'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Move to the target device once, then wrap; later cells reach the underlying
# model through `model.module`.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True, revision=MODEL_REVISION).to(device)
model = torch.nn.DataParallel(model)

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, revision=MODEL_REVISION)
torch.cuda.empty_cache()

In [7]:
import pandas as pd
from PIL import Image
import os
import json
import uuid
from tqdm import tqdm
import pandas as pd
import pyarrow
from datasets import Dataset
from utils import download_images

# Root folder holding the challenge CSVs and the downloaded image folders.
DATA_DIR = 'dataset'

train_csv_path = f'{DATA_DIR}/train.csv'
test_csv_path = f'{DATA_DIR}/test.csv'
train_df, test_df = pd.read_csv(train_csv_path), pd.read_csv(test_csv_path)

# Fetch the product images for each split into its own folder. Later cells look images
# up by the last path component of `image_link` — presumably how download_images names
# them; confirm against utils.download_images.
download_images(train_df['image_link'], f'{DATA_DIR}/trains')
download_images(test_df['image_link'], f'{DATA_DIR}/tests')

In [8]:
import pandas as pd

# Maximum number of training rows kept per entity type. Groups larger than the cap are
# randomly downsampled; entity types not listed here are dropped from training.
target_distribution = {
    'height': 2000,
    'width': 2000,
    'depth': 2000,
    'item_weight': 2000,
    'maximum_weight_recommendation': 500,
    'wattage': 500,
    'voltage': 500,
    'item_volume': 500
}

# Collect each (possibly downsampled) group, then concatenate once: concatenating inside
# the loop re-copies the growing frame every iteration, and seeding it with an empty
# DataFrame relies on deprecated empty-entry concat behaviour in recent pandas.
sampled_groups = []
for entity_name, target_count in target_distribution.items():
    entity_data = train_df[train_df['entity_name'] == entity_name]
    if len(entity_data) > target_count:
        entity_data = entity_data.sample(n=target_count, random_state=42)
    sampled_groups.append(entity_data)

balanced_df = (
    pd.concat(sampled_groups, ignore_index=True)
    if sampled_groups
    else train_df.iloc[0:0].reset_index(drop=True)
)

train_df = balanced_df
train_df['entity_name'].value_counts()


In [9]:
# Process train and test datasets
df_train = Dataset.from_pandas(train_df)
df_test = Dataset.from_pandas(test_df)

In [None]:
# Inspect the test split; the bare expression renders the dataset's summary.
df_test

In [11]:
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile

# Let PIL load image files that end early (presumably partially downloaded images)
# instead of raising while reading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Unit vocabularies shared by several entity types.
_LENGTH_UNITS = ('centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard')
_MASS_UNITS = ('gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton')
_VOLUME_UNITS = (
    'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre', 'fluid ounce',
    'gallon', 'imperial gallon', 'litre', 'microlitre', 'millilitre', 'pint', 'quart',
)

# Allowed answer units for each entity type; these are spliced into the VQA prompts.
# Every entry is its own set object. Set iteration order of strings can change between
# interpreter sessions (str hash randomization), so sort before joining if order matters.
entity_unit_map = {
    'width': set(_LENGTH_UNITS),
    'depth': set(_LENGTH_UNITS),
    'height': set(_LENGTH_UNITS),
    'item_weight': set(_MASS_UNITS),
    'maximum_weight_recommendation': set(_MASS_UNITS),
    'voltage': {'kilovolt', 'millivolt', 'volt'},
    'wattage': {'kilowatt', 'watt'},
    'item_volume': set(_VOLUME_UNITS),
}


class DocVQADataset(Dataset):
    """Map-style dataset yielding `(question, answer, image)` triples for VQA fine-tuning.

    Each row of `data` must provide 'entity_name', 'entity_value' and 'image_link'. The
    image is read from `image_dir` using the last path component of 'image_link' and is
    always returned as an RGB PIL image.
    """

    def __init__(self, data, image_dir="dataset/trains/"):
        # `data` only needs __len__ and integer indexing returning dict-like rows
        # (the notebook passes a Hugging Face datasets.Dataset).
        self.data = data
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]

        image_path = os.path.join(self.image_dir, os.path.basename(example['image_link']))
        if not os.path.exists(image_path):
            # No image means no training example: fail with a clear error rather than
            # continuing with an undefined image. Filter such rows out beforehand.
            raise FileNotFoundError(f"Image not found: {image_path}")
        # convert() returns an in-memory copy, so the file handle is closed right away.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Sorted so the prompt text is identical across kernel sessions; joining the raw
        # set would follow hash order, which changes between interpreter runs.
        units = ' '.join(sorted(entity_unit_map[example['entity_name']]))
        question = "<VQA> What is the " + example['entity_name'] + " of this entity ? answer should be in this list of units only " + units
        answer = example['entity_value']
        return question, answer, image


In [None]:
# Fine-tuned weights from an earlier run (a state_dict written with torch.save).
CHECKPOINT_PATH = 'florence model weight epoch 4.pth'

# map_location lets a checkpoint saved on GPU load on whichever `device` is available.
# NOTE(review): the keys must match the DataParallel-wrapped `model` (i.e. carry the
# "module." prefix) — confirm how the checkpoint was saved.
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval();  # trailing ';' suppresses the very long module repr

In [13]:
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AdamW, AutoProcessor, get_scheduler)

import torch


def collate_fn(batch):
    """Turn a list of `(question, answer, image)` triples into model inputs plus answers.

    Returns:
        (inputs, answers): `inputs` is the processor output for the padded batch of
        questions and images, moved to `device`; `answers` is the tuple of raw answer
        values, tokenized later in `train_model`.
    """
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)
    return inputs, answers


# Create datasets
train_dataset = DocVQADataset(df_train)

# Create DataLoader
batch_size = 1
# collate_fn moves tensors to `device`; CUDA work inside forked DataLoader workers is not
# supported by default, so keep loading in the main process.
num_workers = 0

# Seeded generator makes the shuffle order reproducible (same seed as the downsampling).
shuffle_generator = torch.Generator().manual_seed(42)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    num_workers=num_workers,
    shuffle=True,
    generator=shuffle_generator,
)

In [14]:
from tqdm import tqdm
import os
import torch
from torch.optim import AdamW
from transformers import get_scheduler

def train_model(train_loader, model, processor, epochs=10, lr=1e-6):
    """Fine-tune `model` with AdamW and a linear learning-rate decay to zero.

    Args:
        train_loader: DataLoader yielding `(inputs, answers)` as built by `collate_fn`
            (`inputs` has "input_ids" and "pixel_values"; `answers` are answer strings).
        model: the model to train, optionally wrapped in torch.nn.DataParallel.
        processor: processor whose tokenizer encodes `answers` into labels; it is saved
            alongside every checkpoint.
        epochs: number of passes over `train_loader`.
        lr: initial AdamW learning rate (no warmup).

    After each epoch the average loss is printed and the unwrapped model plus the
    processor are saved under ./model_checkpoints/epoch_<n>.
    """
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    # save_pretrained lives on the wrapped model when DataParallel is used.
    base_model = getattr(model, "module", model)
    pad_token_id = processor.tokenizer.pad_token_id

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}", leave=True)

        for inputs, answers in progress_bar:
            input_ids = inputs["input_ids"].to(device)
            pixel_values = inputs["pixel_values"].to(device)
            labels = processor.tokenizer(
                text=list(answers), return_tensors="pt", padding=True, return_token_type_ids=False
            ).input_ids.to(device)
            # Padding added for shorter answers in a batch must not be scored; -100 is the
            # index PyTorch's cross-entropy loss ignores by default.
            if pad_token_id is not None:
                labels = labels.masked_fill(labels == pad_token_id, -100)

            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
            # DataParallel gathers one loss per replica; reduce them to a scalar.
            loss = outputs.loss.mean()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            loss_value = loss.item()
            train_loss += loss_value
            progress_bar.set_postfix({"Loss": loss_value})

        # Guard against an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        print(f"Epoch [{epoch+1}/{epochs}] Average Training Loss: {avg_train_loss:.4f}")

        output_dir = f"./model_checkpoints/epoch_{epoch+1}"
        os.makedirs(output_dir, exist_ok=True)
        base_model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)

# Keep the vision encoder fixed and fine-tune only the remaining weights
# (drop this line to fine-tune the whole model). `model.module` unwraps DataParallel.
model.module.vision_tower.requires_grad_(False)

train_model(train_loader, model, processor, epochs=2)


In [15]:
def run_example(task_prompt, text_input, image, max_new_tokens=1024, num_beams=3):
    """Generate and parse an answer for one image with the (fine-tuned) model.

    Args:
        task_prompt: string prepended verbatim to `text_input` to form the prompt; it is
            also passed as `task` to `processor.post_process_generation`.
        text_input: remainder of the prompt.
        image: PIL image; converted to RGB when needed.
        max_new_tokens: generation length limit.
        num_beams: beam-search width.

    Returns:
        The result of `processor.post_process_generation` (callers in this notebook
        index it by `task_prompt`).
    """
    prompt = task_prompt + text_input

    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

    # train_model() leaves the model in train mode (dropout active); switch to eval for
    # inference. eval() on a DataParallel wrapper also switches the wrapped module.
    model.eval()
    # `generate` is called on the underlying model rather than the DataParallel wrapper.
    model_for_generation = model.module if hasattr(model, 'module') else model

    with torch.no_grad():
        generated_ids = model_for_generation.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )

    # Special tokens are kept; post_process_generation receives the raw decoded text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    return processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )


In [None]:
# Re-display the test split right before running inference over it.
df_test

In [None]:
import pandas as pd

import os
from PIL import Image
from tqdm import tqdm
from IPython.display import display  # For displaying images in Jupyter environments

# Folder holding the downloaded test images, and how often (in examples) a partial
# results CSV is written so a crash mid-run does not lose every prediction.
TEST_IMAGE_DIR = "dataset/tests/"
CHECKPOINT_EVERY = 10000


def process_examples(dataset, run_example_func, image_dir=TEST_IMAGE_DIR, checkpoint_every=CHECKPOINT_EVERY):
    """Predict an entity value for every row of `dataset`.

    The prompt is "<VQA> What is the <entity> of this entity ? answer should be in this
    list of units only <units>", the same text used for training. Because `run_example`
    concatenates its first two arguments, "<VQA>" is passed as the task prompt and the
    rest as the text input. Units are joined in sorted order so the prompt is stable
    across kernel sessions (set iteration order of strings is not).

    Args:
        dataset: indexable of rows with 'index', 'entity_name' and 'image_link'.
        run_example_func: callable `(task_prompt, text_input, image)` returning a mapping
            keyed by `task_prompt`, as `run_example` does.
        image_dir: folder containing the test images.
        checkpoint_every: a partial CSV is written whenever the position is a multiple
            of this value.

    Returns:
        List of `[index, prediction]` rows. Rows whose image is missing get an empty
        prediction so the output keeps one row per example.
    """
    task_token = "<VQA>"
    rows = []

    for idx in tqdm(range(len(dataset)), desc="Processing examples"):
        example = dataset[idx]
        image_path = os.path.join(image_dir, os.path.basename(example['image_link']))

        if os.path.exists(image_path):
            # convert() returns an in-memory copy, so the file is closed immediately.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            units = ' '.join(sorted(entity_unit_map[example['entity_name']]))
            question = " What is the " + example['entity_name'] + " of this entity ? answer should be in this list of units only " + units
            prediction = run_example_func(task_token, question, image)[task_token]
        else:
            print(f"Image not found: {image_path}")
            prediction = ""

        rows.append([example['index'], prediction])

        if idx % checkpoint_every == 0:
            pd.DataFrame(rows, columns=['index', 'entity_value']).to_csv(f'index_value2_{idx}.csv', index=False)

    return rows


data = process_examples(df_test, run_example)

df = pd.DataFrame(data, columns=['index', 'entity_value'])
df.to_csv('result.csv', index=False)

print("Updated CSV file 'result.csv' generated successfully")
