# Team MSM: Mini-Hackathon for synthesis prediction

## Imports

In [4]:
import ast
import json
import os
import pickle

from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, random_split

## Data

In [5]:
class CandidateEmbedding:
    """Lookup from formula strings to their precomputed embedding vectors.

    The mapping is read once from a pickle file and kept in
    ``self.formulas_to_embedding``; other code in this file accesses that
    attribute directly.
    """

    def __init__(self, embedding_file: str = "data/formulas_to_embedding.pkl") -> None:
        """Load the formula -> embedding mapping from ``embedding_file``.

        Args:
            embedding_file: Path to the pickled mapping. Presumably a dict keyed
                by formula string -- TODO confirm against the data file.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising. Only point this at embedding files you trust.
        with open(embedding_file, "rb") as f:
            self.formulas_to_embedding = pickle.load(f)

    def get_embedding(self, target_formula: str) -> np.ndarray:
        """Return the embedding stored for ``target_formula`` as a NumPy array.

        Raises:
            KeyError: If the mapping is a dict and has no entry for the formula.
        """
        return np.array(self.formulas_to_embedding[target_formula])


In [6]:
class TrainDataset(Dataset):
    """Training samples: a target embedding plus the indices of its precursor sets.

    Each CSV row holds a target formula in column 0 and, in column 1, a Python
    literal that is parsed with ``ast.literal_eval`` (iterated over in
    ``__getitem__`` as a collection of precursor sets). Every precursor set is
    translated to its position in the candidate list loaded from JSON.
    """

    def __init__(
        self,
        csv_file: str = "data/ground_truth_sets.csv",
        candidates_file: str = "data/candidates.json",
    ) -> None:
        """Load the ground-truth table, the embeddings and the candidate list.

        Args:
            csv_file: Path to the ground-truth CSV.
            candidates_file: Path to the JSON file with the candidate precursor
                sets (previously hard-coded to ``data/candidates.json``).
        """
        # Column 1 is stored as text; literal_eval turns it back into lists.
        self.data = pd.read_csv(csv_file, converters={1: ast.literal_eval})

        self.formulas_to_embedding = CandidateEmbedding().formulas_to_embedding

        with open(candidates_file, "r") as f:
            self.candidates = json.load(f)

    def __len__(self) -> int:
        """Return the number of rows in the ground-truth table."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, List[int]]:
        """Return ``(embedding, precursor_indexes)`` for row ``idx``.

        ``embedding`` is whatever the embedding mapping stores for the row's
        target formula (presumably an array-like -- TODO confirm).
        ``precursor_indexes`` has one entry per precursor set of the row;
        sets that are not in the candidate list are reported as ``-1``.
        """
        target_formula = self.data.iloc[idx, 0]
        target_embedding = self.formulas_to_embedding[target_formula]
        precursor_sets = self.data.iloc[idx, 1]
        precursor_indexes = [self.precursor_to_index(p) for p in precursor_sets]
        return target_embedding, precursor_indexes

    def precursor_to_index(self, precursor_set: List[str]) -> int:
        """Return the position of ``precursor_set`` in ``self.candidates``.

        The comparison is order-sensitive, exactly like ``list.index``.
        Returns ``-1`` when the set is not among the candidates.
        """
        if isinstance(precursor_set, list):
            try:
                # O(1) dictionary lookup instead of scanning every candidate.
                return self._candidate_lookup().get(tuple(precursor_set), -1)
            except TypeError:
                # Unhashable elements: fall through to the linear search.
                pass
        try:
            return self.candidates.index(precursor_set)
        except ValueError:
            return -1  # Not a known candidate set.

    def _candidate_lookup(self) -> dict:
        """Build (once) and return a ``tuple(candidate) -> first index`` map."""
        lookup = getattr(self, "_candidate_index", None)
        if lookup is None:
            lookup = {}
            for index, candidate in enumerate(self.candidates):
                # Only list candidates can ever compare equal to a list query.
                if isinstance(candidate, list):
                    # setdefault keeps the first occurrence, like list.index.
                    lookup.setdefault(tuple(candidate), index)
            self._candidate_index = lookup
        return lookup

def train_collate_fn(batch):
    """Collate ``(embedding, precursor_indexes)`` samples into two tensors.

    Args:
        batch: Sequence of ``(embedding, index_list)`` pairs.

    Returns:
        A pair ``(embeddings, indexes)``: the embeddings stacked as float32,
        and the index lists stacked as int64 after right-padding each one
        with ``-1`` up to the longest list in the batch.
    """
    embeddings, index_lists = zip(*batch)

    # Every row is padded to the length of the longest index list.
    longest = max(map(len, index_lists))

    embedding_rows = []
    for embedding in embeddings:
        embedding_rows.append(torch.tensor(embedding, dtype=torch.float32))

    index_rows = []
    for indexes in index_lists:
        filler = [-1] * (longest - len(indexes))
        index_rows.append(torch.tensor(indexes + filler, dtype=torch.long))

    return torch.stack(embedding_rows), torch.stack(index_rows)

In [7]:
class TestDataset(Dataset):
    """Dataset of test targets, served as their stored embeddings.

    The JSON file is indexed by position in ``__getitem__`` and each entry is
    used as a key into the formula -> embedding mapping.
    """

    def __init__(self, json_file: str = "data/test_targets.json") -> None:
        """Read the target list from ``json_file`` and load the embeddings."""
        with open(json_file, "r") as handle:
            self.data = json.load(handle)
        self.formulas_to_embedding = CandidateEmbedding().formulas_to_embedding

    def __len__(self) -> int:
        """Return the number of test targets."""
        return len(self.data)

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return the embedding stored for the ``idx``-th test target."""
        formula = self.data[idx]
        return self.formulas_to_embedding[formula]

## Model

In [8]:
class SynthesisPredictionModel(nn.Module):
    """Three-layer MLP that maps a target embedding to one logit per output class.

    Architecture: Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim).
    The final layer has no activation, so the output is raw logits.
    """

    def __init__(self, input_dim: int = 512, hidden_dim: int = 1024, output_dim: int = 27106):
        """Create the three linear layers.

        Args:
            input_dim: Size of the input embedding.
            hidden_dim: Width of both hidden layers.
            output_dim: Number of output logits.
        """
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, target_formula: Tensor) -> Tensor:
        """Return logits for a batch of target embeddings.

        Args:
            target_formula: Float tensor whose last dimension is ``input_dim``.
                (``nn.Linear`` requires a tensor, not a NumPy array.)

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``output_dim``.
        """
        x = F.relu(self.input_layer(target_formula))
        x = F.relu(self.hidden_layer(x))
        return self.output_layer(x)

## Loss

Output of the last layer: Vector 27106 x 1

Target in the ground truth: sets of precursor materials, which have to be translated to their candidate indices [m, n, ...]

Test target: Embeddings


In [9]:
class CustomRankLoss(nn.Module):
    """Pairwise margin ranking loss over correct vs. incorrect class logits."""

    def __init__(self, margin: float = 10.0):
        """
        Custom loss to ensure the highest logits correspond to the correct indices.
        Args:
            margin (float): Minimum margin by which correct logits must exceed incorrect logits.
        """
        super().__init__()
        self.margin = margin

    def forward(self, logits, padded_correct_indices):
        """
        Args:
            logits (torch.Tensor): Model output of shape (batch_size, num_classes).
            padded_correct_indices (torch.Tensor or sequence of torch.Tensor):
                One row of correct class indices per example. Negative entries
                (the ``-1`` padding / "unknown" marker) are ignored.
        Returns:
            torch.Tensor: Scalar loss. For each example the mean of
            ``relu(margin + incorrect_logit - correct_logit)`` over all
            (correct, incorrect) pairs is taken; these are summed and divided
            by ``batch_size``. Examples with no valid correct index (or with no
            incorrect class left) contribute zero instead of NaN.
        """
        batch_size = logits.size(0)
        # Loop-invariant: every class id, used to select the incorrect logits.
        all_class_ids = torch.arange(logits.size(1), device=logits.device)
        loss = 0.0

        for i in range(batch_size):
            correct_indices = padded_correct_indices[i]
            valid_indices = correct_indices[correct_indices >= 0]
            if valid_indices.numel() == 0:
                # Only padding: the pairwise matrix would be empty and its
                # mean NaN, which would poison the whole batch.
                continue

            correct_logits = logits[i, valid_indices]  # Logits at correct indices
            is_incorrect = torch.isin(
                elements=all_class_ids,
                test_elements=valid_indices,
                invert=True,
            )
            incorrect_logits = logits[i][is_incorrect]  # Remove correct indices
            if incorrect_logits.numel() == 0:
                # Every class is correct, so there is nothing to rank against.
                continue

            # Broadcast (num_correct, 1) against (1, num_incorrect) to get the
            # hinge loss of every correct/incorrect pair.
            pairwise_losses = torch.relu(
                self.margin + incorrect_logits.unsqueeze(0) - correct_logits.unsqueeze(1)
            )

            # Mean pairwise loss for this example
            loss = loss + pairwise_losses.mean()

        if not torch.is_tensor(loss):
            # No example contributed (empty batch or only padding). Return a
            # zero that is still attached to the graph so backward() works.
            return logits.sum() * 0.0

        return loss / batch_size

## Eval

In [10]:
def mean_reciprocal_rank(predicted_indices, correct_indices):
    """
    Calculate the Mean Reciprocal Rank (MRR) for the batch.

    For each example the 1-based position of the first entry of
    ``predicted_indices[i]`` that also occurs in ``correct_indices[i]`` is
    located; the example scores ``1 / position``, or ``0`` when no entry
    matches.

    Args:
        predicted_indices (torch.Tensor): Tensor of shape (batch_size, num_classes)
            holding class indices ordered from best to worst prediction.
        correct_indices (torch.Tensor): Tensor of shape (batch_size, num_correct) with
            the correct indices. Padding values such as ``-1`` are harmless as
            long as they never appear in ``predicted_indices``.

    Returns:
        float: Mean Reciprocal Rank for the batch (``0.0`` for an empty batch).
    """
    batch_size = predicted_indices.size(0)
    if batch_size == 0:
        # Avoid ZeroDivisionError on an empty batch.
        return 0.0

    reciprocal_rank_sum = 0.0
    for i in range(batch_size):
        ranking = predicted_indices[i]
        correct = correct_indices[i].to(ranking.device)
        # Vectorised membership test instead of a Python loop over all classes.
        hit_positions = torch.nonzero(torch.isin(ranking, correct)).flatten()
        if hit_positions.numel() > 0:
            # nonzero() returns positions in ascending order; ranks are 1-based.
            reciprocal_rank_sum += 1.0 / (hit_positions[0].item() + 1)

    return reciprocal_rank_sum / batch_size

In [11]:
def evaluate_model(model, dataloader, device):
    """Compute and print the average MRR of ``model`` over ``dataloader``.

    Args:
        model: Module mapping a batch of embeddings to per-class logits.
        dataloader: Iterable of ``(target_formulas, padded_precursor_indexes)``
            batches, as produced with ``train_collate_fn``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: MRR averaged over all samples (``0.0`` if the dataloader
        yields nothing).

    Note:
        The model is switched to eval mode and left there; callers that keep
        training must call ``model.train()`` again.
    """
    model.eval()
    reciprocal_rank_total = 0.0
    sample_count = 0

    with torch.no_grad():
        for target_formulas, padded_precursor_indexes in dataloader:
            # Move data to the same device as the model
            target_formulas = target_formulas.to(device)
            if not torch.is_tensor(padded_precursor_indexes):
                # Accept a sequence of per-example index tensors as well.
                padded_precursor_indexes = torch.stack(list(padded_precursor_indexes))
            padded_precursor_indexes = padded_precursor_indexes.to(device)

            # Forward pass
            logits = model(target_formulas)  # Shape: (batch_size, output_dim)

            # Get the predicted indices sorted by logits (descending order)
            _, predicted_indices = torch.sort(logits, descending=True)

            # Weight each batch by its size so that a smaller final batch
            # does not count as much as a full one.
            batch_size = predicted_indices.size(0)
            batch_mrr = mean_reciprocal_rank(predicted_indices, padded_precursor_indexes)
            reciprocal_rank_total += batch_mrr * batch_size
            sample_count += batch_size

    # Guard against an empty dataloader (previously a ZeroDivisionError).
    avg_mrr = reciprocal_rank_total / sample_count if sample_count else 0.0
    print(f"Evaluation Complete. Average MRR: {avg_mrr:.4f}")
    return avg_mrr

## Training Loop

In [12]:
# Training-run configuration. The *_interval values are counted in epochs.
num_epochs = 100
# NOTE: log_interval is only referenced by the commented-out per-batch
# logging further down.
log_interval = 1
save_interval = 10
eval_interval = 1

# Checkpoints of this run are written below this directory.
checkpoints_dir = "checkpoints"
os.makedirs(checkpoints_dir, exist_ok=True)

# The whole ground-truth set is used for training in this cell.
train_dataset = TrainDataset()
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=train_collate_fn)

# NOTE(review): the test loader is built here but not used anywhere in this
# cell, and it relies on the default collate function -- confirm that works
# with the stored embedding type before using it.
test_dataset = TestDataset()
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

model = SynthesisPredictionModel()

# Define optimizer and custom loss function
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = CustomRankLoss(margin=100.0)



# Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    # evaluate_model() below switches to eval mode, so re-enable train mode
    # at the start of every epoch.
    model.train()
    total_loss = 0.0
    batch_count = 0

    for batch_idx, (target_formulas, padded_precursor_indexes) in enumerate(train_dataloader):
        # Move data to the same device as the model
        target_formulas = target_formulas.to(device)
        # Iterating the padded index tensor yields one row per example, so
        # this becomes a list of per-example index tensors.
        padded_precursor_indexes = [indices.to(device) for indices in padded_precursor_indexes]

        # Forward pass
        logits = model(target_formulas)  # Shape: (batch_size, output_dim)

        # Compute loss
        loss = criterion(logits, padded_precursor_indexes)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1

        # Log progress
        # if (batch_idx + 1) % log_interval == 0:
        #     print(
        #         f"Epoch [{epoch + 1}/{num_epochs}], "
        #         f"Batch [{batch_idx + 1}/{len(train_dataloader)}], "
        #         f"Loss: {loss.item():.4f}"
        #     )


    # Mean of the per-batch losses of this epoch.
    avg_loss = total_loss / batch_count
    print(f"Epoch [{epoch + 1}/{num_epochs}] Complete. Average Loss: {avg_loss:.4f}")

    # Evaluate the model every eval_interval epochs. NOTE: this uses the
    # training dataloader, so the reported MRR is a training-set score.
    if (epoch + 1) % eval_interval == 0:
        evaluate_model(model, train_dataloader, device)

    # Save a checkpoint (model + optimizer state) every save_interval epochs.
    if (epoch + 1) % save_interval == 0:
        checkpoint_path = os.path.join(checkpoints_dir, f"model_epoch_{epoch + 1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Model saved to {checkpoint_path}")


Epoch [1/100] Complete. Average Loss: 87.2347
Evaluation Complete. Average MRR: 0.0126
Epoch [2/100] Complete. Average Loss: 19.4179
Evaluation Complete. Average MRR: 0.0131
Epoch [3/100] Complete. Average Loss: 14.8596
Evaluation Complete. Average MRR: 0.0132
Epoch [4/100] Complete. Average Loss: 14.2317
Evaluation Complete. Average MRR: 0.0137
Epoch [5/100] Complete. Average Loss: 13.8618
Evaluation Complete. Average MRR: 0.0140
Epoch [6/100] Complete. Average Loss: 13.5755


KeyboardInterrupt: 

In [None]:
# Training loop parameters (the *_interval values are counted in epochs).
# NOTE: log_interval is not used anywhere in this cell.
num_epochs = 100
log_interval = 1
save_interval = 10
eval_interval = 1

# Separate checkpoint directory for this run.
checkpoints_dir = "checkpoints2"
os.makedirs(checkpoints_dir, exist_ok=True)

# Load dataset
dataset = TrainDataset()

# Split dataset into 80% train and 20% evaluate
# NOTE(review): no generator or seed is passed in this cell, so the split is
# presumably different on every run -- confirm whether that is intended.
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, eval_size])

# Both loaders use train_collate_fn, so the eval batches carry padded
# precursor indexes just like the training batches.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=train_collate_fn)
eval_dataloader = DataLoader(eval_dataset, batch_size=4, shuffle=False, collate_fn=train_collate_fn)

model = SynthesisPredictionModel()

# Define optimizer and custom loss function
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = CustomRankLoss(margin=100.0)

# Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    # evaluate_model() below switches to eval mode, so re-enable train mode
    # at the start of every epoch.
    model.train()
    total_loss = 0.0
    batch_count = 0

    for batch_idx, (target_formulas, padded_precursor_indexes) in enumerate(train_dataloader):
        # Move data to the same device as the model
        target_formulas = target_formulas.to(device)
        # Iterating the padded index tensor yields one row per example, so
        # this becomes a list of per-example index tensors.
        padded_precursor_indexes = [indices.to(device) for indices in padded_precursor_indexes]

        # Forward pass
        logits = model(target_formulas)  # Shape: (batch_size, output_dim)

        # Compute loss
        loss = criterion(logits, padded_precursor_indexes)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1

    # Mean of the per-batch losses of this epoch.
    avg_loss = total_loss / batch_count
    print(f"Epoch [{epoch + 1}/{num_epochs}] Complete. Average Loss: {avg_loss:.4f}")

    # Evaluate the model on the eval set every eval_interval epochs
    if (epoch + 1) % eval_interval == 0:
        evaluate_model(model, eval_dataloader, device)

    # Save the model checkpoint every save_interval epochs
    if (epoch + 1) % save_interval == 0:
        checkpoint_path = os.path.join(checkpoints_dir, f"model_epoch_{epoch + 1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Model saved to {checkpoint_path}")

Epoch [1/100] Complete. Average Loss: 89.7008
Evaluation Complete. Average MRR: 0.0075
Epoch [2/100] Complete. Average Loss: 17.5271
Evaluation Complete. Average MRR: 0.0087
Epoch [3/100] Complete. Average Loss: 13.0364
Evaluation Complete. Average MRR: 0.0092
Epoch [4/100] Complete. Average Loss: 12.4751
Evaluation Complete. Average MRR: 0.0101
Epoch [5/100] Complete. Average Loss: 12.1896
Evaluation Complete. Average MRR: 0.0103
Epoch [6/100] Complete. Average Loss: 11.9419
Evaluation Complete. Average MRR: 0.0110
Epoch [7/100] Complete. Average Loss: 11.7364
Evaluation Complete. Average MRR: 0.0114
Epoch [8/100] Complete. Average Loss: 11.5087
Evaluation Complete. Average MRR: 0.0118
Epoch [9/100] Complete. Average Loss: 11.1749
Evaluation Complete. Average MRR: 0.0156
Epoch [10/100] Complete. Average Loss: 10.4544
Evaluation Complete. Average MRR: 0.0197
Model saved to checkpoints2/model_epoch_10.pth
Epoch [11/100] Complete. Average Loss: 9.1984
Evaluation Complete. Average MRR: 0.