How to use:
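- Place the pretrained weights at "pretrained/ram_swin_large_14m.pth" inside the "JiuTian-LION\ram" folder.
- Run "process_dataset_ram.ipynb" from that folder (the notebook uses relative paths such as "pretrained/..." and "./models").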

In [None]:
# Pick the compute device: the DirectML backend when it imports and
# initialises, otherwise the CPU.
try:
    import torch_directml
    DEVICE = torch_directml.device()
except (ImportError, RuntimeError):
    DEVICE = 'cpu'
    print("using CPU")
else:
    print("found AMD GPU")

DirectML device found. Using AMD GPU.


In [None]:
import pandas as pd
import torch
import requests
from PIL import Image
from io import BytesIO
import numpy as np
from tqdm import tqdm
import os
from models.ram import ram
from transform import get_transform

# --- Configuration ---
# All paths are relative, so they resolve against the current working directory.
# Input CSV; process_dataset() reads its 'image_url' column.
DATASET_PATH = 'coco_image_captions.csv'
# Path where the new dataset (input rows plus tag/score columns) will be saved
OUTPUT_PATH = 'coco_image_captions_with_tags.csv'
# Path to the pretrained model weights
PRETRAINED_MODEL_PATH = 'pretrained/ram_swin_large_14m.pth'
# Image size handed to both ram() and get_transform(), so the two stay in sync
IMAGE_SIZE = 384
# Select the compute device once at import time: DirectML when it is
# importable and initialises, CPU otherwise (the failure reason is printed).
try:
    import torch_directml
    DEVICE = torch_directml.device()
except (ImportError, RuntimeError) as e:
    print(f"Could not initialize DirectML, falling back to CPU. Error: {e}")
    DEVICE = 'cpu'
else:
    print(f"Successfully set device to DirectML: {DEVICE}")

# --- Model Loading ---
def load_ram_model(model_path, image_size):
    """Build the Recognize Anything Model (RAM) and move it to DEVICE.

    Args:
        model_path: Checkpoint path, forwarded to ``ram(pretrained=...)``.
        image_size: Input resolution, forwarded to ``ram(image_size=...)``.

    Returns:
        The model in eval mode, placed on the module-level ``DEVICE``.
    """
    # Set before constructing the model; presumably the model code reads
    # CONFIG_PATH to locate its config files -- confirm against models/ram.
    os.environ['CONFIG_PATH'] = './models'

    tagger = ram(pretrained=model_path, image_size=image_size, vit='swin_l')
    tagger.eval()
    tagger = tagger.to(DEVICE)

    print(f"RAM model loaded successfully on {DEVICE}.")
    return tagger

# Confidence floor: a tag is kept only if its sigmoid score is strictly greater.
CUSTOM_THRESHOLD = 0.5
# Hard cap on tags per image; applied to the sorted scores before the floor.
MAX_TAGS = 40

def generate_tags_with_scores(model, image, transform, max_tags, threshold):
    """
    Tag one image and return the tags together with their confidences.

    The per-tag sigmoid scores are sorted in descending order, cut to the
    best ``max_tags`` entries, and then filtered to those strictly above
    ``threshold``. Indices listed in ``model.delete_tag_index`` are dropped.

    Args:
        model: RAM model exposing the attributes used below.
        image: Input image accepted by ``transform``.
        transform: Callable turning ``image`` into a tensor (a batch
            dimension is added here).
        max_tags: Maximum number of tags to keep.
        threshold: Minimum confidence (exclusive) for a tag to be kept.

    Returns:
        ``(tag_string, scores)`` where ``tag_string`` joins the tags with
        ``' | '`` and ``scores`` is the matching list of floats, highest
        first; ``("", [])`` when nothing survives the filters.
    """
    batch = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # Forward pass: image features -> per-tag logits -> confidences.
        visual_features = model.visual_encoder(batch)
        image_embeds = model.image_proj(visual_features)
        attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long
        ).to(DEVICE)
        batch_size = image_embeds.shape[0]

        label_queries = torch.nn.functional.relu(
            model.wordvec_proj(model.label_embed)
        )
        label_queries = label_queries.unsqueeze(0).repeat(batch_size, 1, 1)

        head_output = model.tagging_head(
            encoder_embeds=label_queries,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=attention_mask,
            return_dict=False,
            mode='tagging',
        )
        logits = model.fc(head_output[0]).squeeze(-1)
        confidences = torch.sigmoid(logits).squeeze()

        # Top-K cut first, then the confidence floor on what is left.
        ranked_scores, ranked_indices = torch.sort(confidences, descending=True)
        ranked_scores = ranked_scores[:max_tags]
        ranked_indices = ranked_indices[:max_tags]
        above_floor = ranked_scores > threshold

        kept_indices = ranked_indices[above_floor].cpu().numpy()
        kept_scores = ranked_scores[above_floor].cpu().numpy().tolist()

    # Drop deleted tags while keeping every index paired with its own score.
    survivors = [
        (tag_index, score)
        for tag_index, score in zip(kept_indices, kept_scores)
        if tag_index not in model.delete_tag_index
    ]
    if not survivors:
        return "", []

    valid_indices = [tag_index for tag_index, _ in survivors]
    valid_scores = [score for _, score in survivors]

    # Assumes model.tag_list supports indexing with a list of positions
    # (e.g. a NumPy array) -- TODO confirm against models/ram.
    tag_string = ' | '.join(model.tag_list[valid_indices])
    return tag_string, valid_scores


def download_image(url):
    """Download an image from a URL and return it as a decoded PIL Image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` when the request fails
        or the response body cannot be decoded as an image. The reason is
        printed in both cases.
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # Raise an exception for bad status codes
        image = Image.open(BytesIO(response.content))
        # Image.open only reads the header; force the pixel data to decode
        # here so a corrupt/truncated payload fails inside this try block
        # instead of later in the caller's processing loop.
        image.load()
        return image
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        return None
    except OSError as e:
        # Pillow reports unidentifiable or truncated image data as OSError.
        print(f"Error decoding image from {url}: {e}")
        return None

def process_dataset(df, model, transform):
    """
    Tag every image referenced by ``df`` and attach the results as columns.

    Each row's ``'image_url'`` is downloaded and tagged with the module-level
    ``MAX_TAGS`` / ``CUSTOM_THRESHOLD`` settings. Rows whose image cannot be
    downloaded get an empty tag string and an empty score list.

    Args:
        df: DataFrame with an ``'image_url'`` column; modified in place.
        model: Tagging model passed through to generate_tags_with_scores.
        transform: Image transform passed through to generate_tags_with_scores.

    Returns:
        The same ``df`` with ``'tags'`` and ``'confidence_scores'`` added.
    """
    tag_column = []
    score_column = []

    progress = tqdm(df.iterrows(), total=df.shape[0], desc="Processing images")
    for _, record in progress:
        picture = download_image(record['image_url'])

        if not picture:
            # Download failed: keep the row, but with empty placeholders.
            tag_column.append("")
            score_column.append([])
            continue

        tags, scores = generate_tags_with_scores(
            model, picture, transform, MAX_TAGS, CUSTOM_THRESHOLD
        )
        tag_column.append(tags)
        score_column.append(scores)

    df['tags'] = tag_column
    df['confidence_scores'] = score_column
    return df


import time
if __name__ == "__main__":
    # Load the model and image transformer
    ram_model = load_ram_model(PRETRAINED_MODEL_PATH, IMAGE_SIZE)
    image_transform = get_transform(image_size=IMAGE_SIZE)

    # Load the dataset
    print(f"Loading dataset from {DATASET_PATH}...")
    coco_df = pd.read_csv(DATASET_PATH)

    # --- Row range processed by this run: [num_imgs_start, num_imgs_end) ---
    num_imgs_start = 75000
    num_imgs_end = 118287+1
    # The end bound is applied to the slice so the declared range and the rows
    # actually processed cannot drift apart. .copy() detaches the slice from
    # the full frame because process_dataset() adds columns to it in place.
    coco_df = coco_df.iloc[num_imgs_start:num_imgs_end].copy()
    # ---------------------------------
    start_time = time.time()
    # Process the dataset to get tags and scores
    processed_df = process_dataset(coco_df, ram_model, image_transform)

    # Save the new dataset.
    # NOTE(review): only the rows of this run are written, replacing whatever
    # OUTPUT_PATH held before -- confirm earlier ranges are saved elsewhere.
    print(f"Saving updated dataset to {OUTPUT_PATH}...")
    processed_df.to_csv(OUTPUT_PATH, index=False)
    print("Processing complete!")

    # Read the clock once so both figures describe the same interval, and
    # divide by the number of rows really processed rather than by the
    # nominal range width (which can exceed the rows available).
    elapsed = time.time() - start_time
    processed_count = len(processed_df)
    print(f"Total runtime: {elapsed:.2f} seconds")
    if processed_count:
        print(f"Time per image: {elapsed / processed_count:.2f} seconds")

    print(processed_df.head())




Successfully set device to DirectML: privateuseone:0


BertLMHeadModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


/encoder/layer/0/crossattention/self/query is tied
/encoder/layer/0/crossattention/self/key is tied
/encoder/layer/0/crossattention/self/value is tied
/encoder/layer/0/crossattention/output/dense is tied
/encoder/layer/0/crossattention/output/LayerNorm is tied
/encoder/layer/0/intermediate/dense is tied
/encoder/layer/0/output/dense is tied
/encoder/layer/0/output/LayerNorm is tied
/encoder/layer/1/crossattention/self/query is tied
/encoder/layer/1/crossattention/self/key is tied
/encoder/layer/1/crossattention/self/value is tied
/encoder/layer/1/crossattention/output/dense is tied
/encoder/layer/1/crossattention/output/LayerNorm is tied
/encoder/layer/1/intermediate/dense is tied
/encoder/layer/1/output/dense is tied
/encoder/layer/1/output/LayerNorm is tied
--------------
pretrained/ram_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_swin_large_14m.pth
vit: swin_l
RAM model loaded successfully on privateuseone:0.
Loading dataset from coco_image_captions.csv...


Processing images: 100%|██████████| 43287/43287 [4:32:01<00:00,  2.65it/s]   


Saving updated dataset to coco_image_captions_with_tags.csv...
Processing complete!
Total runtime: 16322.83 seconds
Time per image: 0.38 seconds
       image_id                                          image_url  \
75000    101017  http://images.cocodataset.org/train2017/000000...   
75001    516084  http://images.cocodataset.org/train2017/000000...   
75002     40596  http://images.cocodataset.org/train2017/000000...   
75003     68502  http://images.cocodataset.org/train2017/000000...   
75004    523262  http://images.cocodataset.org/train2017/000000...   

                                               caption_1  \
75000  Two adults participating in frisbee football w...   
75001  A man with a lasso riding a horse on water dur...   
75002        A polar bear scratching its back on a tree.   
75003  Herd of zebras and giraffes on the edge of a f...   
75004            Group posing for a photo on a ski hill.   

                                               caption_2  \
75000  A woma

In [51]:
# Back-of-the-envelope estimate: how many images a run of `hours` covers at
# 0.35 s/image, and how long the rest of the dataset would take.
hours = 7
seconds_available = hours * 60 * 60
data_count = seconds_available / 0.35
print(f"Data for {hours} hours is {round(data_count,0)}")

total_data = 118287
remaining_images = total_data - data_count
print(f"Remaining data to process: {round(remaining_images,0)}")

# NOTE(review): this line assumes 0.33 s/image while the estimate above uses
# 0.35 s/image -- confirm which rate is intended.
remaining_hours = remaining_images * 0.33 / 60 / 60
print(f"Time to process remaining data: {round(remaining_hours,2)} hours")

Data for 7 hours is 72000.0
Remaining data to process: 46287.0
Time to process remaining data: 4.24 hours


In [1]:
# Estimate the wall-clock time for the row range [num_imgs_start, num_imgs_end)
# at a fixed per-image cost in seconds.
num_imgs_start = 75000
num_imgs_end = 118287+1
time_per_image = 0.35

image_count = num_imgs_end - num_imgs_start
estimated_hours = round(image_count * time_per_image / 60 / 60, 2)
print(f"Time to run for {image_count} images is {estimated_hours} hours")

Time to run for 43288 images is 4.21 hours


**In case of error**

In [None]:
import pandas as pd
import re

# Paths
DATASET_PATH = 'coco_image_captions.csv'
OUTPUT_PATH = 'coco_image_captions_with_tags.csv'

# Load the processed dataset (the one with missing tags)
processed_df = pd.read_csv(OUTPUT_PATH)

# Identify rows where tags are missing or empty.
# astype(str) keeps the .str accessor usable even when read_csv parsed the
# column as all-NaN floats (every tag missing); the NaN rows themselves are
# picked up by isna().
tags_column = processed_df['tags']
tags_missing = tags_column.isna() | (tags_column.astype(str).str.strip() == '')

# .copy() makes this an independent frame, so adding or replacing columns on
# it later neither warns about nor depends on its link to processed_df.
missing_df = processed_df[tags_missing].copy()

print(f"Found {len(missing_df)} images with missing tags.")
missing_df.head()

def extract_image_id(url):
    """Return the numeric file stem of a ``.../<digits>.jpg`` URL as an int.

    Returns None when the URL does not end in ``/<digits>.jpg``.
    """
    found = re.search(r'/(\d+)\.jpg$', url)
    if not found:
        return None
    return int(found.group(1))

# Derive image_id from the URL's file name. assign() returns a new frame, so
# the column is not written into a filtered slice of processed_df (which is
# what triggers pandas' chained-assignment warning).
# NOTE(review): the main run's printed output already shows an image_id
# column; this replaces it with the value parsed from the URL -- confirm
# that is intended.
missing_df = missing_df.assign(
    image_id=missing_df['image_url'].apply(extract_image_id)
)
missing_df[['image_url', 'image_id']].head()

from models.ram import ram
from transform import get_transform
import torch
from PIL import Image
from io import BytesIO
import requests
from tqdm import tqdm

# --- Config ---
# Same settings as the main processing cell, repeated so the recovery run
# tags images identically.
PRETRAINED_MODEL_PATH = 'pretrained/ram_swin_large_14m.pth'
IMAGE_SIZE = 384
CUSTOM_THRESHOLD = 0.5
MAX_TAGS = 40

# --- Device setup ---
# Still best-effort (any ordinary failure falls back to CPU), but no longer a
# bare `except:`: KeyboardInterrupt/SystemExit now propagate and the reason
# for the fallback is reported instead of being silently discarded.
try:
    import torch_directml
    DEVICE = torch_directml.device()
except Exception as e:
    print(f"Could not initialize DirectML, falling back to CPU. Error: {e}")
    DEVICE = 'cpu'

# --- Load RAM ---
def load_ram_model(model_path, image_size):
    """Construct the RAM model from ``model_path`` in eval mode on DEVICE.

    Args:
        model_path: Checkpoint path, forwarded to ``ram(pretrained=...)``.
        image_size: Input resolution, forwarded to ``ram(image_size=...)``.

    Returns:
        The constructed model object.
    """
    import os

    # Set before constructing the model; presumably the model code reads
    # CONFIG_PATH to locate its config files -- confirm against models/ram.
    os.environ['CONFIG_PATH'] = './models'

    tagger = ram(pretrained=model_path, image_size=image_size, vit='swin_l')
    tagger.eval()
    # The result of .to() is not rebound, as in the original; this relies on
    # the move happening in place on the module.
    tagger.to(DEVICE)
    return tagger

# Build the tagger and its preprocessing transform once, up front; both use
# IMAGE_SIZE so the transform output matches the size the model was built for.
ram_model = load_ram_model(PRETRAINED_MODEL_PATH, IMAGE_SIZE)
transform = get_transform(image_size=IMAGE_SIZE)

import time

def download_image_with_retry(url, retries=3, delay=2):
    """Download and decode an image, retrying on failure.

    Args:
        url: Address of the image to fetch.
        retries: Maximum number of attempts.
        delay: Seconds to wait between consecutive attempts.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` if every attempt failed
        (each failure is printed).
    """
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            # Image.open only reads the header; decode now so a truncated or
            # corrupt payload counts as a failed attempt instead of raising
            # later in the reprocessing loop.
            image.load()
            return image
        except (requests.exceptions.RequestException, OSError) as e:
            # Pillow reports undecodable image data as OSError.
            print(f"Attempt {attempt+1} failed for {url}: {e}")
            # No point waiting after the last attempt: nothing follows it.
            if attempt < retries - 1:
                time.sleep(delay)
    return None

# Re-tag every row that had no tags, keeping one result per row of missing_df
# (in row order) so the results can be attached as columns afterwards.
new_tags = []
new_scores = []

for _, row in tqdm(missing_df.iterrows(), total=len(missing_df), desc="Reprocessing missing images"):
    image = download_image_with_retry(row['image_url'])
    if image is not None:
        tags, scores = generate_tags_with_scores(ram_model, image, transform, MAX_TAGS, CUSTOM_THRESHOLD)
        new_tags.append(tags)
        new_scores.append(scores)
    else:
        # Still unavailable: keep the placeholders so the row stays "missing".
        new_tags.append("")
        new_scores.append([])

# assign() and set_index() return new frames, so nothing is written into a
# filtered slice of processed_df (the source of pandas' chained-assignment
# warning) and processed_df itself is left untouched.
missing_df = missing_df.assign(
    tags=new_tags,
    confidence_scores=new_scores,
).set_index('image_url')

# Merge the fixed rows back into the main dataframe, aligned on image_url
updated_df = processed_df.set_index('image_url')

# Update only those rows that were missing
updated_df.update(missing_df[['tags', 'confidence_scores']])

# Save final corrected dataset
updated_df = updated_df.reset_index()
updated_df.to_csv(OUTPUT_PATH, index=False)

print("Dataset successfully updated with missing image tags.")

# astype(str) keeps the .str accessor usable even if the column holds no
# strings at all (e.g. every tag is still NaN); NaN rows are caught by isna().
tags_after = updated_df['tags']
still_missing = updated_df[tags_after.isna() | (tags_after.astype(str).str.strip() == '')]
print(f"Remaining missing tags: {len(still_missing)}")
