In [1]:
import os
import pandas as pd
import torch
from torch import nn, optim
from torchvision import models, transforms
from PIL import Image
import requests
from io import BytesIO
import re
from tqdm import tqdm
import src.constants

In [2]:
import re

# Define a function to extract value and unit from entity_value
def extract_value_and_unit(entity_value):
    if pd.isna(entity_value):  # Handle missing values
        return None, None
    match = re.match(r'([0-9.]+)\s*(\w+)', entity_value)
    if match:
        value = float(match.group(1))  # Extract the float value
        unit = match.group(2).strip().lower()  # Extract and normalize the unit
        return value, unit
    return None, None  # In case the value isn't in the expected format


In [3]:
# Preprocessing functions for images and data
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

##### Here is another preprocess to may improve model performance 

In [5]:
preprocess = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [6]:
def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert('RGB')

def extract_value_and_unit(entity_value):
    if pd.isna(entity_value):  # Handle missing values
        return None, None
    match = re.match(r'([0-9.]+)\s*(\w+)', entity_value)
    if match:
        value = float(match.group(1))  # Extract the float value
        unit = match.group(2).strip().lower()  # Extract and normalize the unit
        return value, unit
    return None, None  # In case the value isn't in the expected format

#### here is the updated EntityValuePredictor

In [7]:
class EntityValuePredictor(nn.Module):
    def __init__(self, num_entity_types):
        super(EntityValuePredictor, self).__init__()
        self.backbone = models.resnet50(pretrained=True)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.fc1 = nn.Linear(in_features, 1024)  
        self.fc2 = nn.Linear(1024, 512)         
        self.fc3 = nn.Linear(512, num_entity_types) 
        self.value_head = nn.Linear(512, 1)     
        
    def forward(self, x):
        x = self.backbone(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        entity_type_logits = self.fc3(x)
        value = self.value_head(x)
        return entity_type_logits, value


#### here is different train_model function for optimisation

In [8]:
import torch.optim as optim

# Updated training function
def train_model(num_epochs, model, train_loader, criterion, optimizer):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels, entity_types) in enumerate(tqdm(train_loader)):
            inputs, labels, entity_types = inputs.cuda(), labels.cuda(), entity_types.cuda()
            
            optimizer.zero_grad()
            entity_type_logits, predicted_value = model(inputs)
            
            loss_entity_type = criterion(entity_type_logits, entity_types)
            loss_value = criterion(predicted_value.squeeze(), labels)
            loss = loss_entity_type + loss_value
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")
    
    torch.save(model.state_dict(), 'entity_value_predictor.pth')
    print(f"Model saved at 'entity_value_predictor.pth'")




In [10]:
# Model Definition
# class EntityValuePredictor(nn.Module):
#     def __init__(self, num_entity_types):
#         super(EntityValuePredictor, self).__init__()
#         self.backbone = models.resnet50(pretrained=True)
#         in_features = self.backbone.fc.in_features
#         self.backbone.fc = nn.Identity()
#         self.fc1 = nn.Linear(in_features, 512)
#         self.fc2 = nn.Linear(512, num_entity_types)
#         self.value_head = nn.Linear(512, 1)  # Predict a single value
        
#     def forward(self, x):
#         x = self.backbone(x)
#         x = torch.relu(self.fc1(x))
#         entity_type_logits = self.fc2(x)
#         value = self.value_head(x)
#         return entity_type_logits, value

# Training Process
def train_model(train_data_path, model, criterion_entity, criterion_value, optimizer, entity_types, epochs=3):
    # Load the training data
    train_data = pd.read_csv(train_data_path)
    
    for epoch in range(epochs):
        running_loss = 0.0
        for index, row in tqdm(train_data.iterrows(), total=len(train_data)):
            try:
                # Get image and process it
                image = download_image(row['image_link'])
                image_tensor = preprocess(image)
                
                # Extract entity name, value, and unit
                entity_name = row['entity_name']
                entity_value, unit = extract_value_and_unit(row['entity_value'])
                if entity_value is None:
                    continue  # Skip invalid rows
                
                entity_type_idx = entity_types.index(entity_name)
                
                # Forward pass
                optimizer.zero_grad()
                entity_type_logits, predicted_value = model(image_tensor.unsqueeze(0))
                
                # Compute losses
                loss_entity = criterion_entity(entity_type_logits, torch.tensor([entity_type_idx]))
                loss_value = criterion_value(predicted_value, torch.tensor([entity_value]))
                
                loss = loss_entity + loss_value
                
                # Backward and optimize
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
            except Exception as e:
                print(f"Error processing row {index}: {str(e)}")
        
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_data)}")


    torch.save(model.state_dict(), 'entity_value_predictor.pth')
    print(f"Model saved at 'entity_value_predictor.pth'")


In [11]:
# Initialize the model, loss functions, and optimizer
entity_types = list(src.constants.entity_unit_map.keys())
num_entity_types = len(entity_types)
model = EntityValuePredictor(num_entity_types)



In [12]:
criterion_entity = nn.CrossEntropyLoss()  # For entity type classification
criterion_value = nn.MSELoss()  # For entity value regression
# optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = optim.AdamW(model.parameters(), lr=0.0005)

In [13]:
# Train the model using train.csv
train_data_path = 'dataset/sample_train.csv'
train_model(train_data_path, model, criterion_entity, criterion_value, optimizer, entity_types)


  0%|          | 0/320 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)
 49%|████▉     | 158/320 [02:17<18:05,  6.70s/it]

Error processing row 157: HTTPSConnectionPool(host='m.media-amazon.com', port=443): Max retries exceeded with url: /images/I/71pKB7-3itL.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000022554832FD0>, 'Connection to m.media-amazon.com timed out. (connect timeout=None)'))


 50%|████▉     | 159/320 [02:38<29:31, 11.00s/it]

Error processing row 158: HTTPSConnectionPool(host='m.media-amazon.com', port=443): Max retries exceeded with url: /images/I/71PGzqxh-EL.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x00000225548301D0>, 'Connection to m.media-amazon.com timed out. (connect timeout=None)'))


 50%|█████     | 160/320 [02:59<37:22, 14.02s/it]

Error processing row 159: HTTPSConnectionPool(host='m.media-amazon.com', port=443): Max retries exceeded with url: /images/I/71uqF3qqsuL.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000022554835590>, 'Connection to m.media-amazon.com timed out. (connect timeout=None)'))


 50%|█████     | 161/320 [03:20<42:43, 16.12s/it]

Error processing row 160: HTTPSConnectionPool(host='m.media-amazon.com', port=443): Max retries exceeded with url: /images/I/61vdMgkasML.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000022554833DD0>, 'Connection to m.media-amazon.com timed out. (connect timeout=None)'))


 67%|██████▋   | 214/320 [04:31<12:27,  7.05s/it]

Error processing row 213: HTTPSConnectionPool(host='m.media-amazon.com', port=443): Max retries exceeded with url: /images/I/71XLo+kfY4S.jpg (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000022554834210>, 'Connection to m.media-amazon.com timed out. (connect timeout=None)'))


100%|██████████| 320/320 [05:44<00:00,  1.08s/it]


Epoch [1/3], Loss: 183282.22927762306


100%|██████████| 320/320 [03:52<00:00,  1.38it/s]


Epoch [2/3], Loss: 178517.23054127226


100%|██████████| 320/320 [03:47<00:00,  1.41it/s]

Epoch [3/3], Loss: 178746.96869110168
Model saved at 'entity_value_predictor.pth'





In [14]:
import torch
from src.constants import entity_unit_map  # Assuming this contains your entity names and allowed units

# Load the trained model
def load_model(model_path):
    entity_types = list(entity_unit_map.keys())
    num_entity_types = len(entity_types)
    
    model = EntityValuePredictor(num_entity_types)  # Initialize the model
    model.load_state_dict(torch.load(model_path))   # Load the trained weights
    model.eval()  # Set the model to evaluation mode
    return model, entity_types

# Example usage:
MODEL_PATH = 'entity_value_predictor.pth'
model, entity_types = load_model(MODEL_PATH)


  model.load_state_dict(torch.load(model_path))   # Load the trained weights


In [15]:
def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert('RGB')


def predictor(image_link, group_id, entity_name, model, entity_types):
    try:
        image = download_image(image_link)
        image_tensor = preprocess(image)
        prediction = predict(model, image_tensor, entity_name, entity_types)
        return prediction
    except Exception as e:
        print(f"Error processing {image_link}: {str(e)}")
        return ""
    
    
def predict(model, image_tensor, entity_name, entity_types):
    with torch.no_grad():
        entity_type_logits, predicted_value = model(image_tensor.unsqueeze(0))
    
    # Predict entity type
    entity_type_probs = torch.softmax(entity_type_logits, dim=1)
    predicted_entity_index = torch.argmax(entity_type_probs).item()
    predicted_entity = entity_types[predicted_entity_index]
    
    # If the predicted entity type doesn't match, return empty string
    if predicted_entity != entity_name:
        return ""  # Skip if the model predicts a different entity type
    
    # Get the predicted value
    predicted_value = predicted_value.item()
    
    # Choose the appropriate unit based on the entity type
    unit = next(iter(entity_unit_map[entity_name]))  # Get the correct unit for the entity type
    
    # Format the prediction string as "x unit"
    return f"{predicted_value:.2f} {unit}"

In [16]:
import pandas as pd
from tqdm import tqdm

# Preprocess the images and predict entity values


def process_test_data(test_data_path, output_file, model, entity_types):
    # Load the test dataset
    test_data = pd.read_csv(test_data_path)
    
    # Make predictions
    tqdm.pandas()
    test_data['prediction'] = test_data.progress_apply(
        lambda row: predictor(row['image_link'], row['group_id'], row['entity_name'], model, entity_types), axis=1)
    
    # Save predictions to output CSV file
    test_data[['index', 'prediction']].to_csv(output_file, index=False)
    print(f"Output file generated: {output_file}")

# Example usage:
TEST_DATA_PATH = 'dataset/sample_test.csv'
OUTPUT_FILE = 'dataset/test_out.csv'
process_test_data(TEST_DATA_PATH, OUTPUT_FILE, model, entity_types)


100%|██████████| 88/88 [00:37<00:00,  2.34it/s]

Output file generated: dataset/test_out.csv





# here is the another version(VIT model )to increase the performance of model

In [17]:
import os
import pandas as pd
import torch
from torch import nn
import timm  # Library for Vision Transformer
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
import src.constants
from tqdm import tqdm

class EntityValuePredictor(nn.Module):
    def __init__(self, num_entity_types):
        super(EntityValuePredictor, self).__init__()
        # Use Vision Transformer (ViT) model from timm
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True)
        in_features = self.backbone.get_classifier().in_features
        self.backbone.reset_classifier(0)
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, num_entity_types)
        self.value_head = nn.Linear(512, 1)  # Predict a single value

    def forward(self, x):
        x = self.backbone(x)
        x = torch.relu(self.fc1(x))
        entity_type_logits = self.fc2(x)
        value = self.value_head(x)
        return entity_type_logits, value

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert('RGB')

def load_model(model, model_path):
    # Load the checkpoint directly (without 'model_state_dict')
    checkpoint = torch.load(model_path)

    # Get the model's current state dictionary
    model_dict = model.state_dict()

    # Filter out layers that don't match in size
    pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict and v.size() == model_dict[k].size()}

    # Update the current model's state_dict with matching keys
    model_dict.update(pretrained_dict)

    # Load the updated state_dict into the model
    model.load_state_dict(model_dict)
    return model



def predict(model, image_tensor, entity_name, entity_types):
    with torch.no_grad():
        entity_type_logits, value = model(image_tensor.unsqueeze(0))
    
    entity_type_probs = torch.softmax(entity_type_logits, dim=1)
    predicted_entity_index = torch.argmax(entity_type_probs).item()
    predicted_entity = entity_types[predicted_entity_index]
    
    if predicted_entity != entity_name:
        return ""  # Return empty string if predicted entity doesn't match
    
    predicted_value = value.item()
    
    # Choose appropriate unit based on entity type
    unit = next(iter(src.constants.entity_unit_map[entity_name]))
    
    # Format the prediction string
    return f"{predicted_value:.2f} {unit}"

def predictor(image_link, group_id, entity_name, model, entity_types):
    try:
        image = download_image(image_link)
        image_tensor = preprocess(image)
        prediction = predict(model, image_tensor, entity_name, entity_types)
        return prediction
    except Exception as e:
        print(f"Error processing {image_link}: {str(e)}")
        return ""

if __name__ == "__main__":
    DATASET_FOLDER = 'dataset/'
    MODEL_PATH = 'entity_value_predictor.pth'  # Path to your trained model
    
    # Load the trained model
    model = EntityValuePredictor(num_entity_types)
    model = load_model(model, MODEL_PATH)
    # Load test data
    test = pd.read_csv(os.path.join(DATASET_FOLDER, 'sample_test.csv'))
    
    # Make predictions
    tqdm.pandas()
    test['prediction'] = test.progress_apply(
        lambda row: predictor(row['image_link'], row['group_id'], row['entity_name'], model, entity_types), axis=1)
    
    # Save predictions
    output_filename = os.path.join(DATASET_FOLDER, 'test_out.csv')
    test[['index', 'prediction']].to_csv(output_filename, index=False)

    print(f"Output file generated: {output_filename}")


  checkpoint = torch.load(model_path)
 33%|███▎      | 29/88 [00:14<00:54,  1.09it/s]

Error processing https://m.media-amazon.com/images/I/51BEuVR4ZzL.jpg: HTTPSConnectionPool(host='m.media-amazon.com', port=443): Max retries exceeded with url: /images/I/51BEuVR4ZzL.jpg (Caused by SSLError(SSLError(1, '[SSL: TLSV1_ALERT_INTERNAL_ERROR] tlsv1 alert internal error (_ssl.c:1002)')))


100%|██████████| 88/88 [00:38<00:00,  2.31it/s]

Output file generated: dataset/test_out.csv





In [55]:
checkpoint = torch.load(MODEL_PATH)
print(checkpoint.keys())


odict_keys(['backbone.conv1.weight', 'backbone.bn1.weight', 'backbone.bn1.bias', 'backbone.bn1.running_mean', 'backbone.bn1.running_var', 'backbone.bn1.num_batches_tracked', 'backbone.layer1.0.conv1.weight', 'backbone.layer1.0.bn1.weight', 'backbone.layer1.0.bn1.bias', 'backbone.layer1.0.bn1.running_mean', 'backbone.layer1.0.bn1.running_var', 'backbone.layer1.0.bn1.num_batches_tracked', 'backbone.layer1.0.conv2.weight', 'backbone.layer1.0.bn2.weight', 'backbone.layer1.0.bn2.bias', 'backbone.layer1.0.bn2.running_mean', 'backbone.layer1.0.bn2.running_var', 'backbone.layer1.0.bn2.num_batches_tracked', 'backbone.layer1.0.conv3.weight', 'backbone.layer1.0.bn3.weight', 'backbone.layer1.0.bn3.bias', 'backbone.layer1.0.bn3.running_mean', 'backbone.layer1.0.bn3.running_var', 'backbone.layer1.0.bn3.num_batches_tracked', 'backbone.layer1.0.downsample.0.weight', 'backbone.layer1.0.downsample.1.weight', 'backbone.layer1.0.downsample.1.bias', 'backbone.layer1.0.downsample.1.running_mean', 'backbone.

  checkpoint = torch.load(MODEL_PATH)


#### The another approach (EfficientNet model)for the model

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from transformers import BertTokenizer, BertModel
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import pandas as pd
from PIL import Image
import numpy as np


In [8]:
class EfficientNetTransformerModel(nn.Module):
    def __init__(self, num_units, num_heads, hidden_dim, num_layers):
        super(EfficientNetTransformerModel, self).__init__()
        
        # Load pretrained EfficientNet model
        self.efficientnet = models.efficientnet_b0(pretrained=True)
        
        # Modify the classifier head
        num_ftrs = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier = nn.Identity()
        
        # Define transformer
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers
        )
        
        # Define embedding and output layers
        self.embedding = nn.Linear(num_ftrs, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, num_units)

    def forward(self, images, entities, tgt):
        # Extract features from EfficientNet
        features = self.efficientnet(images)
        
        # Apply linear transformation to the features
        embedded_features = self.embedding(features)
        
        # Transformer expects input in shape (seq_len, batch, features)
        embedded_features = embedded_features.unsqueeze(0)  # Add sequence dimension
        tgt = tgt.unsqueeze(0)  # Add sequence dimension
        
        # Pass through transformer
        transformer_output = self.transformer(embedded_features, tgt)
        
        # Final output layer
        output = self.fc_out(transformer_output.squeeze(0))  # Remove sequence dimension
        return output


###### data prepareation

In [3]:
import requests
from io import BytesIO
from PIL import Image
import pandas as pd
import torch
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def load_image(image_path):
    if image_path.startswith('http'):
        response = requests.get(image_path)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_path).convert('RGB')
    image = transform(image)
    return image

def preprocess_data(data, le_entity, le_units):
    images = []
    entities = []
    values = []
    for _, row in data.iterrows():
        image = load_image(row['image_link'])
        images.append(image)
        entities.append(le_entity.transform([row['entity_name']])[0])  # Encode entity_name
        value_unit = row['entity_value'].split()[-1]  # Extract unit
        values.append(le_units.transform([value_unit])[0])
    
    images_tensor = torch.stack(images)
    entities_tensor = torch.tensor(entities)
    values_tensor = torch.tensor(values)
    return images_tensor, entities_tensor, values_tensor


###### Load Data and Encode Labels

In [4]:
# Load data
train_data = pd.read_csv('dataset/sample_train.csv')
test_data = pd.read_csv('dataset/sample_test.csv')

# Label encode entity names and units
le_entity = LabelEncoder()
train_data['entity_name_encoded'] = le_entity.fit_transform(train_data['entity_name'])

allowed_units = sorted(train_data['entity_value'].str.extract(r'([a-zA-Z]+)$')[0].unique())
le_units = LabelEncoder()
le_units.fit(allowed_units)

# Filter test data to only include known entity names
test_data = test_data[test_data['entity_name'].isin(le_entity.classes_)]

# Encode entity names in the test data
test_data['entity_name_encoded'] = le_entity.transform(test_data['entity_name'])


In [6]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO


train_data = pd.read_csv('dataset/sample_train.csv')
test_data = pd.read_csv('dataset/sample_test.csv')

# Label encode entity names and units
le_entity = LabelEncoder()
train_data['entity_name_encoded'] = le_entity.fit_transform(train_data['entity_name'])
allowed_units = sorted(train_data['entity_value'].str.extract(r'([a-zA-Z]+)$')[0].unique())
le_units = LabelEncoder()
le_units.fit(allowed_units)

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def load_image(image_path):
    if image_path.startswith('http'):
        response = requests.get(image_path)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_path).convert('RGB')
    image = transform(image)
    return image

def preprocess_data(data, le_entity=None, le_units=None):
    images = []
    entities = []
    values = []
    for _, row in data.iterrows():
        image = load_image(row['image_link'])
        images.append(image)
        if le_entity:
            entities.append(row['entity_name_encoded'])
        if 'entity_value' in row:
            values.append(row['entity_value'])
    
    images_tensor = torch.stack(images)
    if le_entity:
        entities_tensor = torch.tensor(entities)
    else:
        entities_tensor = torch.empty(0, dtype=torch.long)  # Placeholder for test data
    
    if le_units:
        values_tensor = torch.tensor([le_units.transform([val.split()[-1]])[0] for val in values])
    else:
        values_tensor = torch.empty(0, dtype=torch.long)  # Placeholder for test data
    
    return images_tensor, entities_tensor, values_tensor

# Prepare data
images_train, entities_train, values_train = preprocess_data(train_data, le_entity, le_units)
images_test, entities_test, _ = preprocess_data(test_data)  # Test data does not have entity_value

# Split data for training and validation
train_indices, val_indices = train_test_split(range(len(images_train)), test_size=0.2, random_state=42)
images_train, images_val = images_train[train_indices], images_train[val_indices]
entities_train, entities_val = entities_train[train_indices], entities_train[val_indices]
values_train, values_val = values_train[train_indices], values_train[val_indices]


###### train the model

In [11]:
# Initialize model, loss function, and optimizer
num_units = len(le_units.classes_)
hidden_dim = 256
num_heads = 4
num_layers = 4

model = EfficientNetTransformerModel(num_units=num_units, num_heads=num_heads, hidden_dim=hidden_dim, num_layers=num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(model, images_train, entities_train, values_train, epochs=3):
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(images_train, entities_train, entities_train)  # Modify as needed
        loss = criterion(outputs.view(-1, num_units), values_train)
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

train(model, images_train, entities_train, values_train)




: 

###### Evaluate and Predict

In [None]:
def evaluate(model, images_val, entities_val, values_val):
    model.eval()
    with torch.no_grad():
        outputs = model(images_val, entities_val, entities_val)  # Modify as needed
        _, predicted = torch.max(outputs.view(-1, num_units), 1)
        accuracy = (predicted == values_val).float().mean()
        print(f'Validation Accuracy: {accuracy.item()}')

evaluate(model, images_val, entities_val, values_val)

def predict(model, images_test, entities_test):
    model.eval()
    with torch.no_grad():
        outputs = model(images_test, entities_test, entities_test)  # Modify as needed
        _, predicted = torch.max(outputs.view(-1, num_units), 1)
        predicted_units = le_units.inverse_transform(predicted.numpy())
        return predicted_units

predictions = predict(model, images_test, entities_test)
