In [1]:
import os
import torch
import zipfile
import kagglehub
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torch.nn import CosineEmbeddingLoss
from torch.nn.functional import normalize
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import CLIPVisionModel, RobertaModel, AutoTokenizer, CLIPFeatureExtractor

In [2]:
!pip install -q git+https://github.com/sajjjadayobi/clipfa.git

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clipfa (setup.py) ... [?25l[?25hdone


In [3]:
!kaggle datasets download -d aladdinpersson/flickr8kimagescaptions


Dataset URL: https://www.kaggle.com/datasets/aladdinpersson/flickr8kimagescaptions
License(s): unknown
Downloading flickr8kimagescaptions.zip to /kaggle/working
 97%|█████████████████████████████████████▉ | 1.01G/1.04G [00:03<00:00, 346MB/s]
100%|███████████████████████████████████████| 1.04G/1.04G [00:03<00:00, 338MB/s]


In [4]:
with zipfile.ZipFile("flickr8kimagescaptions.zip", "r") as zip_ref:
    zip_ref.extractall("flickr8k")


In [5]:
csv_file = "/kaggle/input/dataset/Captions.csv"  
image_dir = "/kaggle/input/flickr8kimagescaptions/flickr8k/images"

# Read Dataset

In [6]:
df = pd.read_csv(csv_file, delimiter=",", names=["Filename", "Caption"], skiprows=1)

df['Filename'] = df['Filename'].astype(str).str.strip()
df['Caption'] = df['Caption'].astype(str).str.strip()

image_files = set(os.listdir(image_dir))

df = df[df["Filename"].isin(image_files)]

if df.empty:
    print("❌ No valid images found in the dataset!")
else:
    print(f"✅ Found {len(df)} valid images with captions.")

train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print(f"📂 Training: {len(train_df)} images")
print(f"📂 Validation: {len(val_df)} images")
print(f"📂 Testing: {len(test_df)} images")

✅ Found 16639 valid images with captions.
📂 Training: 11647 images
📂 Validation: 2496 images
📂 Testing: 2496 images


In [7]:
torch.cuda.empty_cache()
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
vision_encoder = CLIPVisionModel.from_pretrained('SajjadAyoubi/clip-fa-vision').to(device)
preprocessor = CLIPFeatureExtractor.from_pretrained('SajjadAyoubi/clip-fa-vision')
tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text').to(device)

torch.autograd.set_detect_anomaly(True)


config.json:   0%|          | 0.00/4.37k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/350M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]



tokenizer_config.json:   0%|          | 0.00/354 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/875k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.12M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/728 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/473M [00:00<?, ?B/s]

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7824b12963e0>

In [9]:
class ClipDataset(Dataset):
    def __init__(self, df, image_dir, tokenizer, preprocessor):
        self.df = df
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['Filename']
        caption = self.df.iloc[idx]['Caption']
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        
        text_inputs = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=32)  
        image_inputs = self.preprocessor(image, return_tensors="pt")
        
        return img_name, caption, text_inputs, image_inputs

# Train model

In [10]:
def collate_fn(batch):
    text_inputs = []
    image_inputs = []
    image_names = []
    captions = []
    
    for img_name, caption, text_input, image_input in batch:
        text_inputs.append({key: val.squeeze(0) for key, val in text_input.items()})
        image_inputs.append({key: val.squeeze(0) for key, val in image_input.items()})
        image_names.append(img_name)  
        captions.append(caption)  
    text_inputs = {key: torch.stack([x[key] for x in text_inputs]).to(device) for key in text_inputs[0].keys()}
    image_inputs = {key: torch.stack([x[key] for x in image_inputs]).to(device) for key in image_inputs[0].keys()}
    
    return image_names, captions, text_inputs, image_inputs

In [11]:
train_dataset = ClipDataset(train_df, image_dir, tokenizer, preprocessor)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)  

In [12]:
optimizer = torch.optim.Adam(list(vision_encoder.parameters()) + list(text_encoder.parameters()), lr=1e-6)  # کاهش learning rate
cosine_loss = CosineEmbeddingLoss()

In [13]:
for epoch in range(2):  
    vision_encoder.train()
    text_encoder.train()
    total_loss = 0
    
    for batch_idx, (image_names, captions, text_inputs, image_inputs) in enumerate(train_dataloader):
        print(f"Epoch {epoch+1}, Batch {batch_idx}")
        
        text_inputs = {key: val.to(device) for key, val in text_inputs.items()}
        image_inputs = {key: val.to(device) for key, val in image_inputs.items()}
        
        if torch.isnan(text_inputs['input_ids']).any() or torch.isnan(image_inputs['pixel_values']).any():
            print(f"NaN found in batch {batch_idx}")
            continue
        if torch.isinf(text_inputs['input_ids']).any() or torch.isinf(image_inputs['pixel_values']).any():
            print(f"Inf found in batch {batch_idx}")
            continue
        
        text_embedding = text_encoder(**text_inputs).last_hidden_state.mean(dim=1)
        image_embedding = vision_encoder(**image_inputs).last_hidden_state.mean(dim=1)
        
        if torch.isnan(text_embedding).any() or torch.isnan(image_embedding).any():
            print(f"NaN found in embeddings at batch {batch_idx}")
            continue
        
        text_embedding = normalize(text_embedding, p=2, dim=1)
        image_embedding = normalize(image_embedding, p=2, dim=1)
        
        target = torch.ones(text_embedding.size(0)).to(device)  
        loss = cosine_loss(text_embedding, image_embedding, target)
        
        print(f"  Loss for batch {batch_idx}: {loss.item()}")
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_dataloader)
    print(f"📌 Epoch {epoch+1}, Average Loss: {avg_loss}")
    
    torch.save(vision_encoder, f'vision_encoder_epoch_{epoch+1}.pt')
    torch.save(text_encoder, f'text_encoder_epoch_{epoch+1}.pt')

    # آزادسازی حافظه
    torch.cuda.empty_cache()

# ذخیره نهایی مدل‌ها بعد از همه ایپاک‌ها
torch.save(vision_encoder, 'final_vision_encoder.pt')
torch.save(text_encoder, 'final_text_encoder.pt')

print("✅ Models have been saved!")


Epoch 1, Batch 0
  Loss for batch 0: 1.0146584510803223
Epoch 1, Batch 1
  Loss for batch 1: 1.0231678485870361
Epoch 1, Batch 2
  Loss for batch 2: 1.025040626525879
Epoch 1, Batch 3
  Loss for batch 3: 1.0298871994018555
Epoch 1, Batch 4
  Loss for batch 4: 0.9986302852630615
Epoch 1, Batch 5
  Loss for batch 5: 0.9963793754577637
Epoch 1, Batch 6
  Loss for batch 6: 1.0114467144012451
Epoch 1, Batch 7
  Loss for batch 7: 0.9936645030975342
Epoch 1, Batch 8
  Loss for batch 8: 0.9774191379547119
Epoch 1, Batch 9
  Loss for batch 9: 0.9878183603286743
Epoch 1, Batch 10
  Loss for batch 10: 0.9919334650039673
Epoch 1, Batch 11
  Loss for batch 11: 1.0028587579727173
Epoch 1, Batch 12
  Loss for batch 12: 1.0045597553253174
Epoch 1, Batch 13
  Loss for batch 13: 0.9947875738143921
Epoch 1, Batch 14
  Loss for batch 14: 0.9981199502944946
Epoch 1, Batch 15
  Loss for batch 15: 0.9915082454681396
Epoch 1, Batch 16
  Loss for batch 16: 0.9806318283081055
Epoch 1, Batch 17
  Loss for batch 