# In [None]:  (notebook cell marker left over from export)
# main_transformer_ti_generator.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import time
import pickle
import os
import random

# --- Import your new modules ---
# These would be new files you create
# from crystal_tokenizer import CrystalTokenizer 
# from ctg_model import CrystalTransformerGeneratorModel # Your Transformer for generation
# from mcc_model import MultitaskCrystalClassifierModel # Your GNN/Transformer for classification

# --- Placeholder for new modules (define these in separate files) ---
class CrystalTokenizer:
    """Mock tokenizer mapping crystal structures to token-id sequences and back.

    The vocabulary is a hard-coded placeholder; ``vocab_path`` is accepted but
    not read yet.
    """

    def __init__(self, vocab_path, config):
        self.vocab = self.load_vocab(vocab_path)
        self.config = config
        # Special tokens (the vocabulary also contains <UNK>).
        self.sos_token = "<SOS>"
        self.eos_token = "<EOS>"
        self.pad_token = "<PAD>"
        # TODO: Load actual vocabulary and define token-to-id and id-to-token mappings
        print("Mock CrystalTokenizer initialized.")

    def load_vocab(self, vocab_path):
        """Return the mock token -> id mapping (``vocab_path`` is currently ignored)."""
        vocab = {
            # special tokens
            "<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3,
            # space groups (only two listed explicitly; more are added below)
            "<sg_1>": 4, "<sg_225>": 5,
            # elements
            "<Fe>": 234, "<O>": 235, "<Si>": 236,
            # Wyckoff symbols
            "<w_4a>": 500, "<w_8b>": 501,
            # discretized lattice lengths
            "<length_bin_001>": 3000, "<length_bin_300>": 3299,
            # discretized lattice angles
            "<angle_bin_001>": 4000, "<angle_bin_080>": 4079,
            # discretized fractional coordinates
            "<fcoord_bin_001>": 5000, "<fcoord_bin_100>": 5099,
        }
        # Fill in space groups 2..224 with ids 6..228 so the skeleton can run.
        vocab.update({f"<sg_{number}>": number + 4 for number in range(2, 225)})
        return vocab

    def sequence_to_structure(self, token_id_sequence):
        """Mock-decode a token-id sequence.

        Returns a placeholder string when the sequence contains the ids for
        <SOS> (1), <EOS> (2) and <sg_225> (5); otherwise returns None. A real
        implementation would build a pymatgen Structure here.
        """
        print(f"Mock decoding sequence: {token_id_sequence[:10]}...")
        required_ids = (1, 2, 5)  # SOS, EOS, SG_225
        try:
            looks_valid = all(token_id in token_id_sequence for token_id in required_ids)
        except Exception as e:
            logger.warning(f"Failed to decode sequence: {e}")
            return None
        if looks_valid:
            return "mock_pymatgen_structure_object"  # Placeholder
        return None

    def structure_to_sequence_ids(self, structure_obj):
        """Return a fixed mock id sequence; ``structure_obj`` is not inspected.

        Layout: <SOS>, space group, 3 lengths, 3 angles, element, Wyckoff
        symbol, 3 fractional coordinates, <EOS>.
        """
        token_names = (
            [self.sos_token, "<sg_225>"]
            + ["<length_bin_001>"] * 3
            + ["<angle_bin_001>"] * 3
            + ["<Fe>", "<w_4a>"]
            + ["<fcoord_bin_001>"] * 3
            + [self.eos_token]
        )
        return [self.vocab[name] for name in token_names]

class CrystalTransformerGeneratorModel(nn.Module):
    """Autoregressive Transformer over crystal token ids (skeleton CTG).

    The model is built from ``nn.TransformerDecoder``. When no encoder memory
    is supplied, the target embeddings themselves are used as memory and the
    cross-attention is causally masked, so the network behaves like a
    decoder-only language model: the output at position ``t`` depends only on
    tokens ``0..t``.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, config):
        """
        Args:
            vocab_size: size of the embedding table and of the output logits.
            embed_dim: model width.
            num_heads: attention heads per layer.
            num_layers: number of decoder layers.
            config: mapping; ``config['max_seq_len']`` (default 256) sets the
                number of learned positional embeddings.
        """
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = nn.Parameter(torch.randn(1, config.get('max_seq_len', 256), embed_dim))
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        print("Mock CTGModel initialized.")

    def forward(self, tgt_token_ids, memory=None, tgt_mask=None, tgt_key_padding_mask=None):
        """Return next-token logits of shape (batch_size, seq_len, vocab_size).

        Args:
            tgt_token_ids: LongTensor (batch_size, seq_len).
            memory: optional encoder memory. If None, the target embeddings are
                used as memory with a causal mask (decoder-only behaviour).
            tgt_mask: optional self-attention mask; a causal mask is built when
                omitted.
            tgt_key_padding_mask: optional padding mask forwarded to the decoder.

        Raises:
            ValueError: if seq_len exceeds the number of positional embeddings.
        """
        seq_len = tgt_token_ids.size(1)
        max_positions = self.positional_encoding.size(1)
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the {max_positions} available positional embeddings."
            )
        tgt_embed = self.token_embedding(tgt_token_ids) + self.positional_encoding[:, :seq_len, :]

        causal_mask = None
        if tgt_mask is None or memory is None:
            causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(tgt_token_ids.device)
        if tgt_mask is None:
            tgt_mask = causal_mask

        memory_mask = None
        if memory is None:
            # Self-referential memory must be masked as well: without this the
            # cross-attention lets every position read future tokens, so
            # teacher-forced passes would disagree with step-by-step generation.
            memory = tgt_embed
            memory_mask = causal_mask

        output = self.transformer_decoder(
            tgt_embed,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        return self.fc_out(output)  # (batch_size, seq_len, vocab_size)

    @torch.no_grad()
    def generate_sequences(self, batch_size, start_token_id, eos_token_id, max_len, device, pad_token_id=0):
        """Sample ``batch_size`` sequences autoregressively.

        Args:
            batch_size: number of sequences to sample.
            start_token_id: id placed at position 0 of every sequence.
            eos_token_id: id that terminates a sequence.
            max_len: maximum total length including the start token.
            device: device on which the sequences are created.
            pad_token_id: id written after a sequence has emitted EOS
                (default 0, the mock vocabulary's <PAD>).

        Returns:
            (sequences, log_probs): LongTensor (batch_size, L) with L <= max_len,
            and a float tensor (batch_size,) holding the summed log-probability
            of the sampled tokens up to and including EOS. Because this method
            runs under ``torch.no_grad()``, the log-probabilities carry no
            gradient; a policy-gradient update must recompute them.
        """
        sequences = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        total_log_probs = torch.zeros(batch_size, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            # Re-run the model on the whole prefix and keep the last position.
            step_log_probs = F.log_softmax(self(sequences)[:, -1, :], dim=-1)
            sampled = torch.multinomial(step_log_probs.exp(), num_samples=1)  # (batch_size, 1)
            chosen_log_prob = step_log_probs.gather(1, sampled).squeeze(1)

            # Only sequences that were still running contribute to the sum.
            total_log_probs = total_log_probs + torch.where(
                finished, torch.zeros_like(chosen_log_prob), chosen_log_prob
            )

            # Finished sequences receive padding instead of further samples.
            sampled = sampled.masked_fill(finished.unsqueeze(1), pad_token_id)
            sequences = torch.cat([sequences, sampled], dim=1)

            finished = finished | (sampled.squeeze(1) == eos_token_id)
            if bool(finished.all()):
                break

        return sequences, total_log_probs


class MultitaskCrystalClassifierModel(nn.Module):
    """Skeleton multitask classifier (MCC) with topology, magnetism and energy heads.

    The "GNN" is a stack of plain linear layers standing in for real
    message-passing layers; the input is treated as an already pooled
    per-structure feature tensor.
    """

    def __init__(self, input_feat_dim, hidden_dim, num_layers, mcc_config):
        super().__init__()
        self.mcc_config = mcc_config

        # Mock trunk: first layer maps input_feat_dim -> hidden_dim, the
        # remaining ones hidden_dim -> hidden_dim.
        self.gnn_layers = nn.ModuleList()
        in_dim = input_feat_dim
        for _ in range(num_layers):
            self.gnn_layers.append(nn.Linear(in_dim, hidden_dim))
            in_dim = hidden_dim
        self.relu = nn.ReLU()

        # Output heads
        num_topo_classes = mcc_config.get('num_topo_classes', 2)  # e.g., TI / not TI
        self.topology_head = nn.Linear(hidden_dim, num_topo_classes)
        num_mag_classes = mcc_config.get('num_mag_classes', 3)  # e.g., FM / AFM / Para
        self.magnetism_head = nn.Linear(hidden_dim, num_mag_classes)
        self.energy_head = nn.Linear(hidden_dim, 1)  # Formation energy
        print("Mock MCCModel initialized.")

    def forward(self, graph_batch):
        """Return a dict of per-structure predictions.

        ``graph_batch`` is used directly as a (batch_size, input_feat_dim)
        feature tensor. A real model would instead take a graph batch, run
        message passing over node features/edges and pool per graph.
        """
        hidden = graph_batch
        for layer in self.gnn_layers:
            hidden = self.relu(layer(hidden))

        predictions = {}
        predictions['topology_logits'] = self.topology_head(hidden)    # (batch_size, num_topo_classes)
        predictions['magnetism_logits'] = self.magnetism_head(hidden)  # (batch_size, num_mag_classes)
        predictions['formation_energy'] = self.energy_head(hidden)     # (batch_size, 1)
        return predictions

# --- Replay Buffer (can be largely reused) ---
class ReplayBuffer:
    """Fixed-capacity ring buffer of (sequence, reward, log_prob) experiences.

    Once ``max_size`` entries are stored, new entries overwrite the oldest.
    """

    def __init__(self, max_size=10000):
        self.max_size = max_size
        self.buffer = []
        self.position = 0

    def add(self, sequence, reward, log_prob):
        """Store one experience.

        ``sequence`` and ``log_prob`` must be tensors; they are converted to a
        list of ids and a Python float respectively. ``reward`` is stored as is.
        """
        # Grow by one slot until capacity is reached, then overwrite in place.
        if len(self.buffer) < self.max_size:
            self.buffer.append(None)
        token_ids = sequence.cpu().numpy().tolist()
        log_prob_value = log_prob.cpu().item()
        self.buffer[self.position] = (token_ids, reward, log_prob_value)
        self.position = (self.position + 1) % self.max_size

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct experiences uniformly at random.

        Returns (sequences, rewards, log_probs): a list of id lists and two
        numpy arrays, or three empty lists when the buffer is empty.
        """
        draw_count = min(batch_size, len(self.buffer))
        if draw_count == 0:
            return [], [], []
        picked = random.sample(self.buffer, draw_count)
        sequences = [entry[0] for entry in picked]
        rewards = np.array([entry[1] for entry in picked])
        log_probs = np.array([entry[2] for entry in picked])
        return sequences, rewards, log_probs

    def __len__(self):
        return len(self.buffer)

# --- Main Generator Class ---
# Configure logging once at import time (INFO level, timestamped records) and
# create the module-level `logger`. Methods defined earlier in this file
# (e.g. CrystalTokenizer.sequence_to_structure) also reference `logger`; that
# works because the name is only looked up when those methods are called.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class TransformerTIGenerator:
    def __init__(self, config):
        """Build the tokenizer, CTG policy, MCC evaluator, optimizers and bookkeeping.

        Args:
            config: mapping of hyper-parameters. Keys read here:
                ``ctg_vocab_path``, ``ctg_embed_dim``, ``ctg_num_heads``,
                ``ctg_num_layers``, ``mcc_input_feat_dim``, ``mcc_hidden_dim``,
                ``mcc_num_layers``, ``mcc_specific_config``, ``ctg_rl_lr``,
                ``ctg_weight_decay``, ``use_critic_rl``, ``critic_hidden_dims``,
                ``critic_lr`` and ``buffer_size``. The whole mapping is also
                forwarded to the tokenizer and the CTG.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32
        logger.info(f"Using device: {self.device}")

        # Initialize Tokenizer
        self.tokenizer = CrystalTokenizer(
            vocab_path=config.get('ctg_vocab_path', 'mock_vocab.json'),
            config=config
        )
        # Token ids are not guaranteed to be contiguous (the mock vocabulary is
        # sparse), so the embedding/output size must cover the largest id rather
        # than the number of entries; otherwise valid ids index out of range.
        vocab_ids = self.tokenizer.vocab.values()
        self.vocab_size = max(len(self.tokenizer.vocab), max(vocab_ids) + 1)
        self.sos_token_id = self.tokenizer.vocab[self.tokenizer.sos_token]
        self.eos_token_id = self.tokenizer.vocab[self.tokenizer.eos_token]
        self.pad_token_id = self.tokenizer.vocab[self.tokenizer.pad_token]

        # Initialize CTG (Crystal Transformer Generator).
        # In a real scenario the weights would come from a pre-trained checkpoint.
        self.ctg = CrystalTransformerGeneratorModel(
            vocab_size=self.vocab_size,
            embed_dim=config.get('ctg_embed_dim', 256),
            num_heads=config.get('ctg_num_heads', 8),
            num_layers=config.get('ctg_num_layers', 6),
            config=config
        ).to(self.device)
        # self._load_model(self.ctg, config.get('ctg_checkpoint_path')) # Implement this

        # Initialize MCC (Multitask Crystal Classifier)
        self.mcc = MultitaskCrystalClassifierModel(
            input_feat_dim=config.get('mcc_input_feat_dim', 128),
            hidden_dim=config.get('mcc_hidden_dim', 256),
            num_layers=config.get('mcc_num_layers', 4),
            mcc_config=config.get('mcc_specific_config', {})
        ).to(self.device)
        # self._load_model(self.mcc, config.get('mcc_checkpoint_path')) # Implement this

        # The CTG itself is the policy network that RL fine-tunes.
        self.ctg_optimizer = torch.optim.Adam(
            self.ctg.parameters(),
            lr=config.get('ctg_rl_lr', 5e-5),
            weight_decay=config.get('ctg_weight_decay', 1e-6)
        )

        # Optional critic for actor-critic updates. Its input width follows the
        # CTG embedding width.
        # NOTE(review): a 'critic_input_dim' config key is not consulted here — confirm intent.
        self.critic_input_dim = config.get('ctg_embed_dim', 256)
        if self.config.get('use_critic_rl', True):
            self.critic = CriticNetwork(
                input_dim=self.critic_input_dim,
                hidden_dims=config.get('critic_hidden_dims', [256, 128])
            ).to(self.device)
            self.critic_optimizer = torch.optim.Adam(
                self.critic.parameters(),
                lr=config.get('critic_lr', 1e-4)
            )
        else:
            self.critic = None
            # Keep the attribute defined so later code can test it safely.
            self.critic_optimizer = None

        self.results = {
            'rewards': [], 'best_structures_info': [], 'best_rewards': [],
            'mcc_evals': {'topology': [], 'magnetism': [], 'energy': []}  # Store MCC outputs
        }
        self.replay_buffer = ReplayBuffer(config.get('buffer_size', 1000))

    def _load_model(self, model, checkpoint_path):
        """Best-effort load of a state dict into ``model``.

        A missing or non-existent path only logs a warning; a failed load only
        logs an error. In both cases the model keeps its current weights.
        """
        model_name = model.__class__.__name__
        if not (checkpoint_path and os.path.exists(checkpoint_path)):
            logger.warning(f"No checkpoint path provided or path does not exist for {model_name}. Using initialized weights.")
            return
        try:
            state_dict = torch.load(checkpoint_path, map_location=self.device)
            model.load_state_dict(state_dict)
            logger.info(f"Loaded weights from {checkpoint_path} for {model_name}")
        except Exception as e:
            logger.error(f"Failed to load weights for {model_name} from {checkpoint_path}: {e}")

    def generate_and_decode_structures(self, batch_size):
        """Generate sequences with CTG and decode them into pymatgen structures."""
        self.ctg.eval() # Set CTG to evaluation mode for generation
        
        token_sequences_ids, log_probs = self.ctg.generate_sequences(
            batch_size=batch_size,
            start_token_id=self.sos_token_id,
            eos_token_id=self.eos_token_id,
            max_len=self.config.get('max_seq_len', 256),
            device=self.device
        )
        
        decoded_structures = []
        valid_indices = [] # Keep track of sequences that decoded successfully
        
        for i, seq_ids in enumerate(token_sequences_ids):
            # Pass the tensor directly, or convert to list of python ints if tokenizer expects that
            struct_obj = self.tokenizer.sequence_to_structure(seq_ids.cpu().tolist())
            if struct_obj: # If decoding is successful
                decoded_structures.append(struct_obj)
                valid_indices.append(i)
            else:
                logger.debug(f"Sequence {i} failed to decode.")

        # Filter log_probs for successfully decoded structures
        valid_log_probs = log_probs[torch.tensor(valid_indices, device=self.device)] if valid_indices else torch.tensor([], device=self.device)
        
        return decoded_structures, token_sequences_ids[valid_indices], valid_log_probs # Return pymatgen objects, their sequences, and log_probs

    def evaluate_structures_with_mcc(self, pymatgen_structures_list):
        """Evaluate decoded structures using the MCC."""
        if not pymatgen_structures_list:
            return {'topology_probs': [], 'magnetism_probs': [], 'formation_energy': []} # Empty results

        self.mcc.eval() # Set MCC to evaluation mode
        batch_mcc_inputs = []
        for struct_obj in pymatgen_structures_list:
            # --- CRITICAL STEP: Convert pymatgen_structure to MCC input format ---
            # This depends heavily on your MCC's architecture (e.g., graph features for a GNN)
            # For skeleton: create dummy features
            # In a real GNN, this would involve creating a graph object (e.g., torch_geometric.data.Data)
            # For simplicity, assume MCC takes a fixed-size tensor representing the structure
            mock_mcc_input = torch.randn(1, self.config.get('mcc_input_feat_dim', 128)).to(self.device)
            batch_mcc_inputs.append(mock_mcc_input)
        
        if not batch_mcc_inputs:
             return {'topology_probs': [], 'magnetism_probs': [], 'formation_energy': []}

        # This assumes MCC can process a batch of these representations
        # If MCC takes torch_geometric Batch objects, you'd use `torch_geometric.data.Batch.from_data_list()`
        mcc_input_tensor = torch.cat(batch_mcc_inputs, dim=0)

        with torch.no_grad():
            predictions = self.mcc(mcc_input_tensor) # Get dict of predictions

        # Process predictions (e.g., apply softmax for classification)
        evaluations = {
            'topology_probs': F.softmax(predictions['topology_logits'], dim=-1).cpu().numpy() if 'topology_logits' in predictions else [],
            'magnetism_probs': F.softmax(predictions['magnetism_logits'], dim=-1).cpu().numpy() if 'magnetism_logits' in predictions else [],
            'formation_energy': predictions['formation_energy'].cpu().numpy().squeeze() if 'formation_energy' in predictions else []
        }
        return evaluations

    def calculate_rewards(self, evaluations):
        """Calculate rewards based on MCC evaluations (adapt your existing logic)."""
        # Example: Reward for being a TI and having low energy
        # This is highly dependent on your goals and MCC output.
        
        topo_probs = evaluations.get('topology_probs', np.array([]))
        formation_energies = evaluations.get('formation_energy', np.array([]))

        if topo_probs.size == 0 or formation_energies.size == 0:
            return np.array([])

        # Assuming topo_probs is [prob_not_TI, prob_TI]
        ti_probability = topo_probs[:, 1] if topo_probs.ndim > 1 and topo_probs.shape[1] > 1 else topo_probs

        # Reward for being TI (higher prob is better)
        reward_ti = ti_probability 
        
        # Reward for stability (lower energy is better, normalize or scale)
        # Simple example: -energy, maybe clip or scale
        reward_stability = -formation_energies 
        # Normalize stability reward if energies vary a lot
        if len(reward_stability) > 1 and np.std(reward_stability) > 1e-6:
             reward_stability = (reward_stability - np.mean(reward_stability)) / (np.std(reward_stability) + 1e-6)

        # Combine rewards (example weights)
        w_ti = self.config.get('w_ti_reward', 2.0)
        w_stability = self.config.get('w_stability_reward', 1.0)
        
        total_rewards = w_ti * reward_ti + w_stability * reward_stability
        return total_rewards


    def rl_update_step(self, sequences, rewards, log_probs):
        """Perform RL update for the CTG (policy) and critic."""
        if not isinstance(rewards, torch.Tensor):
            rewards_tensor = torch.tensor(rewards, device=self.device, dtype=self.dtype)
        else:
            rewards_tensor = rewards.to(self.device).to(self.dtype)

        if not isinstance(log_probs, torch.Tensor):
            log_probs_tensor = torch.tensor(log_probs, device=self.device, dtype=self.dtype)
        else:
            log_probs_tensor = log_probs.to(self.device).to(self.dtype)


        if log_probs_tensor.numel() == 0 or rewards_tensor.numel() == 0:
            logger.warning("Empty log_probs or rewards, skipping RL update.")
            return 0.0, 0.0 if self.critic else 0.0
        
        # --- Policy (CTG) Update ---
        self.ctg.train()
        if self.critic:
            self.critic.train()
            # For critic input, we need a representation of the state (sequence)
            # This is a placeholder: assumes critic can take the raw sequence tensor (padded)
            # Or, use embeddings from CTG. For now, let's assume a method to get state features.
            # This part needs careful design: what represents the "state" for the critic?
            # Let's use a dummy input for the critic for now.
            # A better approach might be to use the hidden states of the CTG at <EOS> or an average.
            
            # Pad sequences for critic input if they are of variable length
            # This requires knowing the max_len used for generation or a config param
            padded_sequences = nn.utils.rnn.pad_sequence(
                [torch.tensor(s, device=self.device) for s in sequences], 
                batch_first=True, 
                padding_value=self.pad_token_id
            )
            # This is still problematic as critic expects fixed size input unless it's also a sequence model
            # For skeleton: let's assume critic takes a fixed size input (e.g. mean embedding)
            # This is a placeholder for getting critic input features from sequences
            critic_input_features = torch.randn(len(sequences), self.critic_input_dim, device=self.device)


            value_predictions = self.critic(critic_input_features).squeeze()
            advantages = rewards_tensor - value_predictions.detach()
            policy_loss = -(log_probs_tensor * advantages).mean()
            
            critic_loss = F.mse_loss(value_predictions, rewards_tensor)
            
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self.config.get('clip_grad', True):
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.config.get('max_grad_norm', 1.0))
            self.critic_optimizer.step()
        else: # REINFORCE
            # Normalize rewards
            rewards_normalized = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8)
            policy_loss = -(log_probs_tensor * rewards_normalized).mean()
            critic_loss = torch.tensor(0.0)

        self.ctg_optimizer.zero_grad()
        policy_loss.backward()
        if self.config.get('clip_grad', True):
            torch.nn.utils.clip_grad_norm_(self.ctg.parameters(), self.config.get('max_grad_norm', 1.0))
        self.ctg_optimizer.step()
        
        return policy_loss.item(), critic_loss.item() if self.critic else 0.0


    def train_rl_step(self):
        """Run one RL cycle: generate, evaluate, reward, store, update.

        Returns:
            dict with ``mean_reward`` (of the freshly generated batch),
            ``policy_loss`` and ``critic_loss``. The losses are 0 whenever the
            update is skipped (nothing decoded, no rewards, or the replay
            buffer is still too small).
        """
        batch_size = self.config.get('rl_batch_size', 16)

        # 1. Generate sequences with the CTG; only decodable ones are returned.
        decoded_structures_pmg, generated_sequences_ids, current_log_probs = \
            self.generate_and_decode_structures(batch_size)

        if not decoded_structures_pmg:
            logger.warning("No structures were successfully generated or decoded in this step.")
            return {'mean_reward': 0, 'policy_loss': 0, 'critic_loss': 0}

        # 2. Evaluate structures with MCC
        evaluations = self.evaluate_structures_with_mcc(decoded_structures_pmg)

        # 3. Calculate rewards (atleast_1d guards against a 0-d result for a single structure)
        total_rewards_np = np.atleast_1d(np.asarray(self.calculate_rewards(evaluations)))

        if total_rewards_np.size == 0:
            logger.warning("No rewards calculated, likely due to evaluation issues.")
            return {'mean_reward': 0, 'policy_loss': 0, 'critic_loss': 0}
        if total_rewards_np.shape[0] != len(decoded_structures_pmg):
            logger.warning(
                f"Got {total_rewards_np.shape[0]} rewards for {len(decoded_structures_pmg)} structures; skipping step."
            )
            return {'mean_reward': 0, 'policy_loss': 0, 'critic_loss': 0}

        mean_reward_current_batch = float(np.mean(total_rewards_np))

        # Store every (sequence, reward, log_prob) triple in the replay buffer.
        for seq_ids, reward, log_prob in zip(generated_sequences_ids, total_rewards_np, current_log_probs):
            self.replay_buffer.add(seq_ids, float(reward), log_prob)

        # Only update once the buffer holds enough experiences.
        min_buffer = self.config.get('rl_min_buffer_for_update', batch_size // 2)
        if len(self.replay_buffer) < min_buffer:
            logger.info(f"Replay buffer size {len(self.replay_buffer)}, need {min_buffer} to update.")
            return {'mean_reward': mean_reward_current_batch, 'policy_loss': 0, 'critic_loss': 0}

        sampled_sequences, sampled_rewards, sampled_log_probs = self.replay_buffer.sample(batch_size)

        if not sampled_sequences:
            logger.warning("Replay buffer sampled an empty batch.")
            return {'mean_reward': mean_reward_current_batch, 'policy_loss': 0, 'critic_loss': 0}

        # 4. Perform RL update
        policy_loss, critic_loss = self.rl_update_step(sampled_sequences, sampled_rewards, sampled_log_probs)

        # Track results
        self.results['rewards'].append(mean_reward_current_batch)

        # Store batch-mean MCC evaluations. The result keys differ from the
        # evaluation keys, so they are mapped explicitly.
        eval_key_for_result = {
            'topology': 'topology_probs',
            'magnetism': 'magnetism_probs',
            'energy': 'formation_energy',
        }
        for result_key, history in self.results['mcc_evals'].items():
            eval_key = eval_key_for_result.get(result_key, result_key)
            values = np.atleast_1d(np.asarray(evaluations.get(eval_key, [])))
            if values.size > 0:
                history.append(np.mean(values, axis=0))  # Mean over batch

        # Track best structures (simplified)
        best_idx_current_batch = int(np.argmax(total_rewards_np))
        best_reward_current_batch = float(total_rewards_np[best_idx_current_batch])

        if not self.results['best_rewards'] or best_reward_current_batch > max(self.results['best_rewards']):
            self.results['best_rewards'].append(best_reward_current_batch)
            topo_probs = np.atleast_1d(np.asarray(evaluations.get('topology_probs', [])))
            energies = np.atleast_1d(np.asarray(evaluations.get('formation_energy', [])))
            best_struct_info = {
                'sequence': generated_sequences_ids[best_idx_current_batch].cpu().tolist(),
                'reward': best_reward_current_batch,
                'mcc_topo_prob': topo_probs[best_idx_current_batch] if topo_probs.size > 0 else None,
                'mcc_energy': energies[best_idx_current_batch] if energies.size > 0 else None,
            }
            self.results['best_structures_info'].append(best_struct_info)
            logger.info(f"New best structure found with reward: {best_reward_current_batch:.4f}")

        return {
            'mean_reward': mean_reward_current_batch,
            'policy_loss': policy_loss,
            'critic_loss': critic_loss
        }

    def train_rl_fine_tuning(self, num_iterations=None):
        """Run the RL fine-tuning loop, then save a final checkpoint and the results.

        ``num_iterations`` defaults to ``config['num_rl_iterations']`` (1000).
        Progress is logged every ``log_frequency`` iterations and a checkpoint
        is written every ``save_frequency`` iterations (never at iteration 0).
        """
        total_iterations = num_iterations
        if total_iterations is None:
            total_iterations = self.config.get('num_rl_iterations', 1000)

        logger.info(f"Starting RL fine-tuning for {total_iterations} iterations")

        for iteration in range(total_iterations):
            step_results = self.train_rl_step()

            should_log = iteration % self.config.get('log_frequency', 10) == 0
            if should_log:
                mean_reward = step_results.get('mean_reward', 0)
                policy_loss = step_results.get('policy_loss', 0)
                critic_loss = step_results.get('critic_loss', 0)
                logger.info(
                    f"Iter {iteration} | Mean Reward: {mean_reward:.3f} | "
                    f"Policy Loss: {policy_loss:.4f} | "
                    f"Critic Loss: {critic_loss:.4f}"
                )

            should_save = iteration % self.config.get('save_frequency', 100) == 0 and iteration > 0
            if should_save:
                self.save_checkpoint(f"rl_checkpoint_iter_{iteration}.pt")

        logger.info("RL fine-tuning completed.")
        self.save_checkpoint("final_rl_checkpoint.pt")
        self.save_results("rl_training_results.pkl")

    def save_checkpoint(self, filename):
        """Write CTG (and critic, if present) weights, optimizer states, config and results.

        The file is placed in ``config['checkpoint_dir']`` (created if needed).
        The MCC is not saved.
        """
        target_dir = self.config.get('checkpoint_dir', './checkpoints_transformer_ti')
        os.makedirs(target_dir, exist_ok=True)
        checkpoint_path = os.path.join(target_dir, filename)

        payload = dict(
            ctg_state_dict=self.ctg.state_dict(),
            ctg_optimizer_state_dict=self.ctg_optimizer.state_dict(),
            config=self.config,
            # Number of completed updates, since a reward is recorded per update.
            iteration=len(self.results['rewards']),
            results=self.results,
        )
        if self.critic:
            payload.update(
                critic_state_dict=self.critic.state_dict(),
                critic_optimizer_state_dict=self.critic_optimizer.state_dict(),
            )

        torch.save(payload, checkpoint_path)
        logger.info(f"Saved RL fine-tuning checkpoint to {checkpoint_path}")

    def load_rl_checkpoint(self, checkpoint_path):
        """Restore CTG/critic weights, optimizer states and results from a checkpoint.

        Args:
            checkpoint_path: file previously written by ``save_checkpoint``.

        Returns:
            The stored iteration count, or 0 when the path does not exist or
            loading fails (the failure is logged, not raised).
        """
        if not os.path.exists(checkpoint_path):
            logger.error(f"RL Checkpoint path {checkpoint_path} does not exist.")
            return 0
        try:
            # The checkpoint contains the config dict and numpy-based results in
            # addition to tensors, which the restricted ``weights_only=True``
            # unpickler (the default in recent PyTorch releases) rejects.
            # SECURITY: full unpickling can execute arbitrary code; only load
            # checkpoints that this program wrote itself.
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
            self.ctg.load_state_dict(checkpoint['ctg_state_dict'])
            self.ctg_optimizer.load_state_dict(checkpoint['ctg_optimizer_state_dict'])
            if self.critic and 'critic_state_dict' in checkpoint:
                self.critic.load_state_dict(checkpoint['critic_state_dict'])
                if 'critic_optimizer_state_dict' in checkpoint:
                    self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

            self.results = checkpoint.get('results', self.results)  # Load past results
            iter_num = checkpoint.get('iteration', 0)
            logger.info(f"Loaded RL fine-tuning checkpoint from {checkpoint_path} at iteration {iter_num}")
            return iter_num
        except Exception as e:
            logger.error(f"Failed to load RL checkpoint: {e}")
            return 0
            
    def save_results(self, filename):
        """Pickle ``self.results`` into ``config['results_dir']`` (created if needed)."""
        output_dir = self.config.get('results_dir', './results_transformer_ti')
        os.makedirs(output_dir, exist_ok=True)
        results_path = os.path.join(output_dir, filename)
        with open(results_path, 'wb') as handle:
            handle.write(pickle.dumps(self.results))
        logger.info(f"Saved RL results to {results_path}")

# --- Mock Critic Network (can be adapted from your original PolicyNetwork/CriticNetwork) ---
class CriticNetwork(nn.Module):
    """Fully connected network producing one scalar per input feature vector.

    The network is a stack of ``Linear -> activation`` pairs, one per entry of
    ``hidden_dims``, followed by a final ``Linear`` layer with a single output.
    All linear weights are Xavier-normal initialised and all biases are zero.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dims: Widths of the hidden layers, in order. May be empty, in
            which case the network is a single ``Linear(input_dim, 1)``.
        activation: ``'relu'`` or ``'tanh'``; any other value falls back to ReLU.
    """

    def __init__(self, input_dim, hidden_dims=(256, 128), activation='relu'):
        super().__init__()
        # Look up the activation *class* so that every hidden layer gets its
        # own module instance instead of one instance shared across positions.
        act_cls = {'relu': nn.ReLU, 'tanh': nn.Tanh}.get(activation, nn.ReLU)
        layers = []
        current_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(current_dim, h_dim))
            layers.append(act_cls())
            current_dim = h_dim
        # Output head: a single scalar per input row.
        layers.append(nn.Linear(current_dim, 1))
        self.model = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Apply Xavier-normal init to linear weights and zero the biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, state_features):
        """Run the MLP on ``state_features``.

        Args:
            state_features: Tensor whose last dimension is ``input_dim``,
                e.g. ``(batch_size, input_dim)``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of 1.
        """
        return self.model(state_features)


if __name__ == '__main__':
    # --- Demo configuration for TransformerTIGenerator ---
    # Built up group by group (insertion order matches the flat layout this
    # replaces). It is only a starting point and needs to be extended for
    # real experiments.
    config = {}

    # Vocabulary and pre-trained model locations.
    config.update({
        'ctg_vocab_path': 'path/to/your/ctg_vocab.json',  # vocabulary file, still to be created
        'ctg_checkpoint_path': None,  # e.g. 'path/to/pretrained_ctg.pt' once the CTG is pre-trained
        'mcc_checkpoint_path': None,  # e.g. 'path/to/pretrained_mcc.pt' once the MCC is pre-trained
    })

    # Generator (CTG) architecture.
    config.update({
        'max_seq_len': 150,  # upper bound on generated crystal sequence length
        'ctg_embed_dim': 128,
        'ctg_num_heads': 4,
        'ctg_num_layers': 3,
    })

    # Classifier (MCC) architecture.
    config.update({
        'mcc_input_feat_dim': 64,  # example value: 64 features extracted per graph
        'mcc_hidden_dim': 128,
        'mcc_num_layers': 3,
        'mcc_specific_config': {
            'num_topo_classes': 2,  # TI vs not-TI
            'num_mag_classes': 3,  # FM, AFM, Non-magnetic
        },
    })

    # Critic settings.
    config.update({
        'use_critic_rl': True,
        'critic_hidden_dims': [128, 64],
        'critic_input_dim': 128,  # should match the feature size used to represent a state (sequence)
    })

    # RL fine-tuning loop and optimisation.
    config.update({
        'rl_batch_size': 8,  # kept small for RL fine-tuning
        'rl_min_buffer_for_update': 4,
        'buffer_size': 500,
        'num_rl_iterations': 100,  # number of RL fine-tuning steps
        'log_frequency': 5,
        'save_frequency': 50,
        'ctg_rl_lr': 1e-5,
        'critic_lr': 5e-5,
        'clip_grad': True,
        'max_grad_norm': 1.0,
    })

    # Reward weights.
    config.update({
        'w_ti_reward': 2.5,
        'w_stability_reward': 1.5,
    })

    # Output locations.
    config.update({
        'checkpoint_dir': './checkpoints_transformer_ti_skel',
        'results_dir': './results_transformer_ti_skel',
    })

    logger.info("Initializing TransformerTIGenerator...")
    ti_generator = TransformerTIGenerator(config)

    # --- Smoke-test the RL fine-tuning loop ---
    # With no pre-trained CTG/MCC checkpoints configured above, the models are
    # randomly initialised, so this run exercises the loop rather than
    # producing meaningful results.
    # To resume from an earlier run, first call:
    #   ti_generator.load_rl_checkpoint('path_to_your_rl_checkpoint.pt')
    logger.info("Starting mock RL fine-tuning loop...")
    rl_steps = config['num_rl_iterations']
    ti_generator.train_rl_fine_tuning(num_iterations=rl_steps)

    logger.info("Script finished.")
    # Next step: inspect ti_generator.results and the saved checkpoints/structures.
