In [1]:
import os
import json
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler
from datasets import load_dataset
from tqdm.auto import tqdm
from dotenv import load_dotenv
import torch.nn as nn
import torch.optim as optim
import plotly.express as px
import logging

In [2]:
import os
os.environ['HF_HOME'] = '/nas/ucb/satvik/hf_cache'

In [82]:
MODEL_NAME = "google/gemma-1.1-2b-it" # Using 2b version for potentially faster loading/iteration
DATASET_NAME = "XythicK/Chemistry"
DATASET_CONFIG = None  # Chemistry dataset doesn't have a config
# Using a very small slice for quick testing/demonstration
DATASET_SLICE_TRAIN = "train[:180]"
DATASET_SLICE_TEST = "train[180:420]"  # Held-out set from Chemistry dataset
NUM_EPOCHS = 10
BATCH_SIZE = 2 # Keep small for memory constraints
MAX_SEQ_LENGTH = 256 # Reduce sequence length for memory
LEARNING_RATE = 1e-5 # Adjusted learning rate
LOSS_FILE = "train_losses.txt"
TARGET_LAYER_INDEX = -2 # Second to last layer
MLORA_RANK = 256 # Rank for MLoRA matrices
CLUSTERABILITY_WEIGHT = 10.0 # Weight for clusterability term in loss function

In [4]:
def setup():
    """Load environment variables and set device."""
    load_dotenv()
    # hf_token = os.getenv("HF_TOKEN")
    hf_token = ""
    if not hf_token:
        logging.warning("HF_TOKEN environment variable not found. Ensure you are logged in via huggingface-cli login or set the HF_TOKEN.")
    # Use HF_TOKEN if available (required for Gemma models)
    # Note: transformers automatically uses HF_TOKEN env var if present

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")
    return device

In [5]:
class MLoRAAdapter(nn.Module):
    def __init__(self, input_dim, output_dim, rank=8):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.rank = rank
        
        logging.info(f"Initializing MLoRA with dims: input={input_dim}, output={output_dim}, rank={rank}")
        
        # MLoRA matrices
        self.A = nn.Linear(input_dim, rank, bias=False)
        self.B = nn.Linear(rank, rank, bias=False)
        self.C = nn.Linear(rank, output_dim, bias=False)
        
        # Initialize with small values
        nn.init.normal_(self.A.weight, std=0.01)
        nn.init.normal_(self.B.weight, std=0.01)
        nn.init.normal_(self.C.weight, std=0.01)
        
        # Scale factor to prevent output from dominating the residual
        self.scaling = 0.1
        
        # ReLU activations for H1 and H2
        self.activation = nn.ReLU()
    
    def forward(self, x):
        # Log shape for debugging
        batch_size = x.size(0)
        seq_len = x.size(1) if x.dim() > 2 else 1
        
        # Reshape if needed to handle different input formats
        if x.dim() > 2:
            # If x is [batch_size, seq_len, hidden_dim]
            orig_shape = x.shape
            x_reshaped = x.view(-1, self.input_dim)
        else:
            # If x is already [batch_size, hidden_dim]
            x_reshaped = x
            
        # MLoRA forward pass: x -> A -> H1 -> B -> H2 -> C -> out
        h1 = self.activation(self.A(x_reshaped))
        h2 = self.activation(self.B(h1))
        out = self.C(h2)
        
        # Apply scaling
        out = out * self.scaling
        
        # Reshape back to original shape if needed
        if x.dim() > 2:
            out = out.view(orig_shape)
            
        return out
    
    def get_B_matrix(self):
        # Return B matrix for clusterability computation
        return self.B.weight

In [6]:
class ChemistryDataset(Dataset):
    def __init__(self, texts, topics, subtopics, tokenizer, max_length):

        assert len(texts) == len(topics) == len(subtopics), "Texts, topics, and subtopics must have the same length."

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.encodings = []
        
        logging.info("Tokenizing dataset...")
        skipped_count = 0
        
        for i, text in enumerate(tqdm(texts, desc="Tokenizing")):
            # Basic check for valid text (ensure it's a string and not too short)
            if text and isinstance(text, str) and len(text.strip()) > 10:
                try:
                    # Tokenize, pad, and truncate
                    tokenized = tokenizer(
                        text,
                        truncation=True,
                        max_length=self.max_length,
                        padding="max_length",
                        return_tensors="pt",
                    )
                    # Remove the batch dimension added by the tokenizer
                    input_ids = tokenized["input_ids"].squeeze(0)
                    attention_mask = tokenized["attention_mask"].squeeze(0)

                    # Create labels by cloning input_ids
                    # For Causal LM, labels are typically input_ids.
                    # Padding token ids in labels are replaced with -100 to be ignored in loss calculation.
                    labels = input_ids.clone()
                    labels[labels == self.tokenizer.pad_token_id] = -100

                    self.encodings.append({
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "labels": labels,
                        "topic": topics[i],
                        "subtopic": subtopics[i]
                    })
                except Exception as e:
                    logging.warning(f"Skipping text due to tokenization error: {e}. Text snippet: {text[:100]}...")
                    skipped_count += 1
            else:
                # Log or count texts that are None, not strings, or too short
                if not (text and isinstance(text, str)):
                    logging.debug(f"Skipping non-string or empty text entry: {type(text)}")
                elif len(text.strip()) <= 10:
                    logging.debug(f"Skipping short text entry: {text[:30]}...")
                skipped_count += 1

        if skipped_count > 0:
            logging.warning(f"Skipped {skipped_count} invalid or short text entries.")
        if not self.encodings:
            raise ValueError("No valid data processed. Check dataset content, filtering, and 'TEXT' field extraction.")


    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        # Return the dictionary directly, DataLoader will batch items
        return self.encodings[idx]

In [7]:
data_hf = load_dataset(
    DATASET_NAME,
    None,
    split='train[180:2000]', 
    trust_remote_code=True
    )

In [8]:
set(data_hf['Topic'])

{'Inorganic Chemistry', 'Organic Chemistry', 'Physical Chemistry'}

In [9]:
set(data_hf['Subtopic'])

{'Alkenes', 'Kinetics', 'Nomenclature', 'Oxidation States'}

In [75]:
def load_and_prepare_data(tokenizer, dataset_slice, dataset_name_arg, dataset_config_arg, max_seq_length_arg, batch_size_arg):
    """Loads and preprocesses the dataset."""
    logging.info(f"Loading dataset: {dataset_name_arg}, Config: {dataset_config_arg}, Slice: {dataset_slice}...")
    try:
        # Use arguments for dataset name and config
        data_hf = load_dataset(dataset_name_arg, dataset_config_arg, split=dataset_slice, trust_remote_code=True)

        # Extract text data - ensure we get strings from the 'TEXT' field for the chemistry dataset
        # Also handle cases where 'TEXT' might be missing or not a string for robustness
        texts = []
        topics = []
        subtopics = []
        for item in data_hf:
            question_content = item.get('Question')
            answer_content = item.get("Answer_1")
            topic = item.get("Topic")
            subtopic = item.get("Subtopic")
            text_content = f"<start_of_turn>user\n{question_content}<end_of_turn>\n<start_of_turn>model\n{answer_content}<end_of_turn>"
            # text_content = question_content + answer_content
            if isinstance(text_content, str):
                texts.append(text_content)
                topics.append(topic)
                subtopics.append(subtopic)
            else:
                logging.debug(f"Found item without valid 'TEXT' field or non-string content: {item}")

        logging.info(f"Loaded {len(texts)} text documents from 'TEXT' field.")
        if not texts:
            raise ValueError("No text data found in the 'TEXT' field of the dataset slice.")
    except Exception as e:
        logging.error(f"Failed to load or process dataset: {e}")
        raise

    # Use argument for max_seq_length
    dataset = ChemistryDataset(texts, topics, subtopics, tokenizer, max_seq_length_arg)
    if len(dataset) == 0:
        raise ValueError("Dataset created, but contains no processable entries. Check data and tokenization.")

    # Shuffle training data, don't shuffle test data
    # This logic assumes specific slice names for train/test.
    # Consider a more robust way if slice names vary (e.g., pass a boolean is_train flag).
    is_train = "train[" in dataset_slice and ":180]" in dataset_slice # Based on your example slice
    # is_train = "train" in dataset_slice.lower() # A more general check

    dataloader = DataLoader(dataset, batch_size=batch_size_arg, shuffle=is_train) # Use argument for batch_size
    logging.info(f"Created DataLoader with {len(dataloader)} batches.")
    return dataloader

In [11]:
def load_model_and_tokenizer():
    """Loads the model and tokenizer."""
    hf_token = ""
    logging.info(f"Loading tokenizer: {MODEL_NAME}")
    # trust_remote_code=True might be needed for some models
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, token = hf_token)

    # Set pad token if missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logging.info("Set pad_token to eos_token")

    logging.info(f"Loading model: {MODEL_NAME}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        token = hf_token,
        cache_dir="/nas/ucb/satvik/hf_cache",
        # Use torch_dtype=torch.float16 or bfloat16 for memory efficiency if GPU supports it
        # torch_dtype=torch.bfloat16,
        # device_map="auto" # Can help distribute large models across GPUs/CPU
    )
    # Ensure model's pad token id matches tokenizer's
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer

In [12]:
def model_forward_with_mlora(model, mlora_adapter, input_ids, attention_mask, labels=None):
    """
    Custom forward pass that applies MLoRA adapter to the target MLP's output.
    """
    # Prepare inputs
    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
    
    # If labels provided, add them to inputs
    if labels is not None:
        model_inputs["labels"] = labels
    
    # Get original outputs - we need to handle the hidden states
    model_inputs["output_hidden_states"] = True
    outputs = model(**model_inputs)
    
    # We need to modify this to return modified logits
    return outputs

In [13]:
def register_mlora_hook(model, mlora_adapter, target_layer_index):
    """
    Registers a forward hook on the target MLP layer to apply MLoRA.
    """
    # For Gemma, the path to the MLP output is model.model.layers[index].mlp
    target_layer = model.model.layers[target_layer_index].mlp
    
    # Get shape information for debugging
    logging.info(f"Registering hook on layer: {target_layer}")
    
    # We need to track whether this is the first forward pass to log dimension info
    first_pass = [True]
    
    def mlora_hook(module, input, output):
        """
        Hook that applies MLoRA to the MLP output.
        - input: tuple containing tensor of shape [batch_size, seq_len, hidden_dim]
        - output: tensor of shape [batch_size, seq_len, hidden_dim]
        """
        if first_pass[0]:
            logging.info(f"Hook input shape: {input[0].shape}")
            logging.info(f"Hook output shape: {output.shape}")
            first_pass[0] = False
        
        try:
            # Apply MLoRA adapter to the input tensor
            mlora_output = mlora_adapter(input[0])
            
            # Handle shape mismatch if needed
            if mlora_output.shape != output.shape:
                logging.warning(f"Shape mismatch: MLoRA output {mlora_output.shape}, MLP output {output.shape}")
                # Reshape to match output dimensions
                mlora_output = mlora_output.view_as(output)
            
            # Add residual connection
            return output + mlora_output
            
        except RuntimeError as e:
            # If there's an error, log it but return the original output to prevent training failure
            logging.error(f"Error in MLoRA hook: {e}")
            return output
    
    # Register the hook to run after the MLP forward pass
    hook_handle = target_layer.register_forward_hook(mlora_hook)
    return hook_handle

In [26]:
def evaluate_model(model, mlora_adapter, dataloader, device):
    """Evaluates the model on the provided dataloader."""

    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    num_batches = 0
    
    # Initialize dictionaries to track losses per topic and subtopic
    topic_losses = {}
    subtopic_losses = {}
    topic_counts = {}
    subtopic_counts = {}
    
    with torch.no_grad():  # No gradient calculation during evaluation
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            # Move batch tensors to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            topics = batch['topic']
            subtopics = batch['subtopic']
            
            # Forward pass
            outputs = model_forward_with_mlora(
                model=model,
                mlora_adapter=mlora_adapter,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            # Calculate loss (CE only for evaluation)
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            
            # Update total loss
            total_loss += loss.item()
            num_batches += 1
            
            # Update topic and subtopic losses
            for topic, subtopic in zip(topics, subtopics):
                # Update topic losses
                if topic not in topic_losses:
                    topic_losses[topic] = 0.0
                    topic_counts[topic] = 0
                topic_losses[topic] += loss.item()
                topic_counts[topic] += 1
                
                # Update subtopic losses
                if subtopic not in subtopic_losses:
                    subtopic_losses[subtopic] = 0.0
                    subtopic_counts[subtopic] = 0
                subtopic_losses[subtopic] += loss.item()
                subtopic_counts[subtopic] += 1
    
    # Calculate average losses
    avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    
    # Calculate average losses per topic and subtopic
    avg_topic_losses = {topic: loss / topic_counts[topic] for topic, loss in topic_losses.items()}
    avg_subtopic_losses = {subtopic: loss / subtopic_counts[subtopic] for subtopic, loss in subtopic_losses.items()}
    
    return avg_loss, avg_topic_losses, avg_subtopic_losses

In [15]:
device = setup()

In [16]:
device

device(type='cuda')

In [17]:
model, tokenizer = load_model_and_tokenizer()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
model.to(device)

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): GemmaRMSNorm((2048,),

In [83]:
train_dataloader = load_and_prepare_data(
            tokenizer, DATASET_SLICE_TRAIN, DATASET_NAME, DATASET_CONFIG, MAX_SEQ_LENGTH, BATCH_SIZE
        )
test_dataloader = load_and_prepare_data(
            tokenizer, DATASET_SLICE_TEST, DATASET_NAME, DATASET_CONFIG, MAX_SEQ_LENGTH, BATCH_SIZE
        )

Tokenizing:   0%|          | 0/180 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/240 [00:00<?, ?it/s]

In [22]:
model.config.hidden_size

2048

In [84]:
def intervention(mlora, cluster_index, num_clusters=4, style='OFF'):
    new_mlora = MLoRAAdapter(
        input_dim=mlora.input_dim,
        output_dim=mlora.output_dim,
        rank=mlora.rank
    )
    new_mlora.load_state_dict(mlora.state_dict())
    B_matrix = new_mlora.B.weight.data  # Shape: [rank, rank]
    cluster_size = B_matrix.shape[0] // num_clusters
    start_idx = cluster_index * cluster_size
    end_idx = (cluster_index + 1) * cluster_size
    
    if style == 'OFF':
        # Original behavior - zero out the specified cluster
        B_matrix[start_idx:end_idx, start_idx:end_idx] = 0
    else:  # style == 'ON'
        # Zero out all clusters except the specified one
        for i in range(num_clusters):
            curr_start = i * cluster_size
            curr_end = (i + 1) * cluster_size
            if i != cluster_index:
                B_matrix[curr_start:curr_end, curr_start:curr_end] = 0
    
    return new_mlora

In [85]:
mlora_adapter_path: str = 'results/mlora_adapter.pt'

input_dim = model.config.hidden_size
output_dim = model.config.hidden_size
safe_rank = min(MLORA_RANK, input_dim, output_dim)  

mlora_adapter = MLoRAAdapter(
            input_dim=input_dim,
            output_dim=output_dim,
            rank=safe_rank
    )
mlora_adapter.load_state_dict(torch.load(mlora_adapter_path))

imlora_adapter = intervention(mlora_adapter, cluster_index=0, num_clusters=4, style='OFF')

In [86]:
imlora_adapter.B.weight.shape

torch.Size([256, 256])

In [87]:
hook_handle = register_mlora_hook(model, imlora_adapter, TARGET_LAYER_INDEX)

imlora_adapter.to(device)

MLoRAAdapter(
  (A): Linear(in_features=2048, out_features=256, bias=False)
  (B): Linear(in_features=256, out_features=256, bias=False)
  (C): Linear(in_features=256, out_features=2048, bias=False)
  (activation): ReLU()
)

In [88]:
test_loss, topic_losses, subtopic_losses = evaluate_model(model, imlora_adapter, test_dataloader, device)

Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

In [89]:
test_loss

9.036648738384248

In [90]:
topic_losses

{'Inorganic Chemistry': 9.002590307548864,
 'Organic Chemistry': 9.636113495662295,
 'Physical Chemistry': 8.45674735446309}

In [91]:
subtopic_losses

{'Oxidation States': 9.002590307548864,
 'Alkenes': 9.919207445780437,
 'Nomenclature': 9.332798549107142,
 'Kinetics': 8.45674735446309}

In [98]:
def remove_all_hooks(model):
    """Remove all hooks from a model and its submodules."""
    # Remove hooks from each module
    for module in model.modules():
        module._forward_hooks.clear()
        module._forward_pre_hooks.clear()
        module._backward_hooks.clear()

In [99]:
results_dict = {
    'ON': [],
    'OFF': []
}

for style in ['ON', 'OFF']:
    for cluster_index in range(4):
        print(f"Evaluating style={style}, cluster_index={cluster_index}")
        remove_all_hooks(model)
        imlora_adapter = intervention(
            mlora_adapter, 
            cluster_index=cluster_index, 
            num_clusters=4, 
            style=style,
            )
        imlora_adapter.to(device)
        hook_handle = register_mlora_hook(model, imlora_adapter, TARGET_LAYER_INDEX)
        test_loss, topic_losses, subtopic_losses = evaluate_model(
            model, imlora_adapter, test_dataloader, device
        )
        results_dict[style].append({
            'topic_losses': topic_losses,
            'subtopic_losses': subtopic_losses
        })

Evaluating style=ON, cluster_index=0


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

Evaluating style=ON, cluster_index=1


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

Evaluating style=ON, cluster_index=2


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

Evaluating style=ON, cluster_index=3


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

Evaluating style=OFF, cluster_index=0


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

Evaluating style=OFF, cluster_index=1


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

Evaluating style=OFF, cluster_index=2


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

Evaluating style=OFF, cluster_index=3


Evaluating:   0%|          | 0/120 [00:00<?, ?it/s]

In [100]:
results_dict

{'ON': [{'topic_losses': {'Inorganic Chemistry': 5.2309326769700695,
    'Organic Chemistry': 5.787021308109678,
    'Physical Chemistry': 4.535642568455186},
   'subtopic_losses': {'Oxidation States': 5.2309326769700695,
    'Alkenes': 5.995877891116672,
    'Nomenclature': 5.563246397745042,
    'Kinetics': 4.535642568455186}},
  {'topic_losses': {'Inorganic Chemistry': 5.202829268441271,
    'Organic Chemistry': 5.756752967834473,
    'Physical Chemistry': 4.490657379460889},
   'subtopic_losses': {'Oxidation States': 5.202829268441271,
    'Alkenes': 5.963387976752387,
    'Nomenclature': 5.535358315422421,
    'Kinetics': 4.490657379460889}},
  {'topic_losses': {'Inorganic Chemistry': 5.247637093956791,
    'Organic Chemistry': 5.79934916002997,
    'Physical Chemistry': 4.5467608085898465},
   'subtopic_losses': {'Oxidation States': 5.247637093956791,
    'Alkenes': 6.00380441877577,
    'Nomenclature': 5.580289954230899,
    'Kinetics': 4.5467608085898465}},
  {'topic_losses': {

In [105]:
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

data = results_dict

# Create subplots: 2 rows, 2 columns
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=('ON - Topics', 'OFF - Topics', 'ON - Subtopics', 'OFF - Subtopics'),
    vertical_spacing=0.15,
    horizontal_spacing=0.1
)

# Elegant dark colors
colors = ['#1ABC9C', '#E74C3C', '#3498DB', '#F1C40F']

# Plot ON Topics (top left)
for i, cluster_data in enumerate(data['ON']):
    topics = list(cluster_data['topic_losses'].keys())
    values = [round(val, 3) for val in cluster_data['topic_losses'].values()]
    
    fig.add_trace(
        go.Bar(
            x=topics,
            y=values,
            name=f'Cluster {i}',
            marker_color=colors[i],
            showlegend=(i == 0)  # Only show legend for first set
        ),
        row=1, col=1
    )

# Plot OFF Topics (top right)
for i, cluster_data in enumerate(data['OFF']):
    topics = list(cluster_data['topic_losses'].keys())
    values = [round(val, 3) for val in cluster_data['topic_losses'].values()]
    
    fig.add_trace(
        go.Bar(
            x=topics,
            y=values,
            name=f'Cluster {i}',
            marker_color=colors[i],
            showlegend=False
        ),
        row=1, col=2
    )

# Plot ON Subtopics (bottom left)
for i, cluster_data in enumerate(data['ON']):
    subtopics = list(cluster_data['subtopic_losses'].keys())
    values = [round(val, 3) for val in cluster_data['subtopic_losses'].values()]
    
    fig.add_trace(
        go.Bar(
            x=subtopics,
            y=values,
            name=f'Cluster {i}',
            marker_color=colors[i],
            showlegend=False
        ),
        row=2, col=1
    )

# Plot OFF Subtopics (bottom right)
for i, cluster_data in enumerate(data['OFF']):
    subtopics = list(cluster_data['subtopic_losses'].keys())
    values = [round(val, 3) for val in cluster_data['subtopic_losses'].values()]
    
    fig.add_trace(
        go.Bar(
            x=subtopics,
            y=values,
            name=f'Cluster {i}',
            marker_color=colors[i],
            showlegend=False
        ),
        row=2, col=2
    )

# Update layout
fig.update_layout(
    title='Cluster Interventions on MLoRA: ON and OFF',
    height=800,
    width=1200,
    barmode='group',
    font=dict(
        family="Computer Modern",
        size=16
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
        font=dict(size=16)
    )
)

# Update y-axes labels
fig.update_yaxes(
    title_text="Loss",
    row=1, col=1,
    gridcolor='lightgray',
    zerolinecolor='lightgray',
    title_font=dict(size=16),
    tickfont=dict(size=16)
)
fig.update_yaxes(
    title_text="Loss",
    row=1, col=2,
    gridcolor='lightgray',
    zerolinecolor='lightgray',
    title_font=dict(size=16),
    tickfont=dict(size=16)
)
fig.update_yaxes(
    title_text="Loss",
    row=2, col=1,
    gridcolor='lightgray',
    zerolinecolor='lightgray',
    title_font=dict(size=16),
    tickfont=dict(size=16)
)
fig.update_yaxes(
    title_text="Loss",
    row=2, col=2,
    gridcolor='lightgray',
    zerolinecolor='lightgray',
    title_font=dict(size=16),
    tickfont=dict(size=16)
)

# Update x-axes
for i in range(1, 3):
    for j in range(1, 3):
        fig.update_xaxes(
            tickangle=45,
            gridcolor='lightgray',
            zerolinecolor='lightgray',
            row=i, col=j,
            tickfont=dict(size=16)
        )

# Update subplot titles
fig.update_annotations(font_size=16)

fig.update_layout(
    width=1200,
    height=800,
)

# Show the plot
fig.show()

In [106]:
results_dict

{'ON': [{'topic_losses': {'Inorganic Chemistry': 5.2309326769700695,
    'Organic Chemistry': 5.787021308109678,
    'Physical Chemistry': 4.535642568455186},
   'subtopic_losses': {'Oxidation States': 5.2309326769700695,
    'Alkenes': 5.995877891116672,
    'Nomenclature': 5.563246397745042,
    'Kinetics': 4.535642568455186}},
  {'topic_losses': {'Inorganic Chemistry': 5.202829268441271,
    'Organic Chemistry': 5.756752967834473,
    'Physical Chemistry': 4.490657379460889},
   'subtopic_losses': {'Oxidation States': 5.202829268441271,
    'Alkenes': 5.963387976752387,
    'Nomenclature': 5.535358315422421,
    'Kinetics': 4.490657379460889}},
  {'topic_losses': {'Inorganic Chemistry': 5.247637093956791,
    'Organic Chemistry': 5.79934916002997,
    'Physical Chemistry': 4.5467608085898465},
   'subtopic_losses': {'Oxidation States': 5.247637093956791,
    'Alkenes': 6.00380441877577,
    'Nomenclature': 5.580289954230899,
    'Kinetics': 4.5467608085898465}},
  {'topic_losses': {

In [107]:
with open('results/interventions_results_chemistry.json', 'w') as f:
    json.dump(results_dict, f, indent=4)