In [13]:
import torch
import clip
import torch.nn as nn

device = "mps"

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import clip
from PIL import Image
import pandas as pd
import os

# Dataset class to load data
class MeasuresDataset(Dataset):
    def __init__(self, csv_file, image_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        page = str(int(self.data.iloc[idx]['Page']))
        label = self.data.iloc[idx]['Final Score']
        text_description = self.data.iloc[idx]['Description']

        img_path = os.path.join(self.image_dir, f"{page}.jpg")
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)
        
        text_description = clip.tokenize([text_description]).squeeze(0)

        return image, text_description, torch.tensor(label, dtype=torch.float)

# Preprocessing and dataset loading
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = MeasuresDataset(csv_file='measures_context.csv', image_dir='./combined', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [8]:
def reward_function(predicted_score, target_score):
    return torch.exp(-torch.abs(predicted_score - target_score))

In [25]:
class GFlowNetAgent(nn.Module):
    def __init__(self, clip_model):
        super(GFlowNetAgent, self).__init__()
        self.clip_model = clip_model
        self.fc = nn.Linear(clip_model.visual.output_dim * 2, 3*224*224)

    def forward(self, image, text):
        with torch.no_grad():
            image_features = self.clip_model.encode_image(image)
            text_features = self.clip_model.encode_text(text)

        combined_features = torch.cat((image_features, text_features), dim=1)
        print(combined_features.shape)
        action = self.fc(combined_features)
        
        return action

In [26]:
def train_gflownet(agent, optimizer, dataloader, model, num_epochs=5):
    for epoch in range(num_epochs):
        agent.train()
        epoch_loss = 0
        for images, texts, labels in dataloader:
            images, texts, labels = images.to(device), texts.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            action = agent(images, texts)
            print(action.shape)
            generated_image = images + action.view(images.shape)
            
            predicted_score = model(generated_image, texts)
            reward = reward_function(predicted_score.squeeze(), labels)
            
            loss = -reward.mean()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], GFlowNet Loss: {avg_loss:.4f}")

In [27]:
def generate_image_gflownet(agent, text, target_score, num_steps=10):
    state = torch.randn(1, 3, 224, 224).to(device)
    text = clip.tokenize([text]).to(device)
    
    for _ in range(num_steps):
        action = agent(state, text)
        state += action.view(state.shape)
    
    return state

In [28]:
# Load CLIP model
clip_model, preprocess = clip.load("ViT-B/32")
clip_model = clip_model.to(device)

# Initialize agent and optimizer
agent = GFlowNetAgent(clip_model).to(device)
optimizer = optim.Adam(agent.parameters(), lr=1e-3)

# Train the GFlowNet model
train_gflownet(agent, optimizer, dataloader, model=None, num_epochs=5)

# Generate an image for a specific target score
description = "An image description."
target_score = 0.7
generated_image = generate_image_gflownet(agent, description, target_score)

torch.Size([32, 1024])
torch.Size([32, 150528])


TypeError: 'NoneType' object is not callable