In [20]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import re
from typing import List, Dict, Tuple

class AutoregressiveCitationMatcher(nn.Module):
    def __init__(
        self,
        model_name: str = "gpt2",
        max_length: int = 512,
        cite_token: str = "<CITE>",
        ref_token: str = "<REF>",
        device: torch.device = None
    ):
        super().__init__()
        
        # Set device
        self.device = device if device is not None else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Add special tokens for citation and reference
        special_tokens = {
            'additional_special_tokens': [cite_token, ref_token]
        }
        self.tokenizer.add_special_tokens(special_tokens)
        
        # For GPT models, we might need to set pad token if it's not set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Initialize encoder for both source and target texts
        # We use the same model for both since autoregressive models maintain context
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        
        # Resize token embeddings
        self.model.resize_token_embeddings(len(self.tokenizer))
        
        self.cite_token = cite_token
        self.ref_token = ref_token
        self.max_length = max_length
        
        hidden_size = self.model.config.hidden_size
        # We'll use these to project the [REF] embeddings from both source and target
        self.source_projector = nn.Linear(hidden_size, hidden_size).to(self.device)
        self.target_projector = nn.Linear(hidden_size, hidden_size).to(self.device)
        
        print(f"Using device: {self.device}")

    def prepare_source_context(
        self,
        text: str,
        target_citation: str,
        window_size: int = 100
    ) -> str:
        """Prepare source context with citation and reference tokens."""
        citation_pattern = re.escape(f"[[{target_citation}]]")
        # Add both [CITE] and [REF] tokens
        modified_text = re.sub(citation_pattern, f"{self.cite_token}", text)
        modified_text = f"{modified_text} {self.ref_token}"
        
        cite_pos = modified_text.find(self.cite_token)
        if cite_pos == -1:
            raise ValueError("Citation token not found in text")
            
        # Include full context up to the [REF] token
        start = max(0, cite_pos - window_size)
        return modified_text[start:]

    def prepare_target_page(
        self,
        page_content: str,
        max_summary_length: int = 200
    ) -> str:
        """Prepare target page with reference token."""
        # Take first paragraph as summary
        first_para = page_content.split('\n\n')[0].strip()
        if len(first_para) > max_summary_length:
            first_para = first_para[:max_summary_length] + "..."
            
        return f"{first_para} {self.ref_token}"

    def get_token_embedding(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        token_id: int
    ) -> torch.Tensor:
        """Extract embedding for specific token."""
        # Get token positions (batch_size, num_occurrences)
        token_positions = (input_ids == token_id).nonzero()
        
        if len(token_positions) == 0:
            raise ValueError(f"Token id {token_id} not found in sequence")
        
        # Extract embeddings at token positions
        token_embeddings = hidden_states[
            token_positions[:, 0],
            token_positions[:, 1]
        ]
        
        # If multiple occurrences, raise an error
        if len(token_positions) > 1:
            raise ValueError(f"There should only be one occurance of REF token")
            
        return token_embeddings

    def encode_text(
        self,
        text: str,
        is_source: bool = True
    ) -> torch.Tensor:
        """Encode text and extract reference token embedding."""
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).to(self.device)
        
        # Get hidden states from the model
        outputs = self.model(
            **inputs,
            output_hidden_states=True,
            return_dict=True
        )
        
        # Get [REF] token embedding from the last hidden state
        ref_token_id = self.tokenizer.convert_tokens_to_ids(self.ref_token)
        ref_embeddings = self.get_token_embedding(
            outputs.hidden_states[-1],
            inputs['input_ids'],
            ref_token_id
        )
        return ref_embeddings
        
        # Project the embeddings based on whether it's source or target
    
        # if is_source:
        #     return self.source_projector(ref_embeddings)
        # else:
        #     return self.target_projector(ref_embeddings)

    def forward(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        temperature: float = 0.07
    ) -> torch.Tensor:
        """Compute similarity between source and target [REF] embeddings."""
        # Encode all contexts and references
        source_embeddings = []
        target_embeddings = []
        
        for context in source_contexts:
            source_emb = self.encode_text(context, is_source=True)
            source_embeddings.append(source_emb)
            
        for page in target_pages:
            target_emb = self.encode_text(page, is_source=False)
            target_embeddings.append(target_emb)
        
        # Stack embeddings
        source_embeddings = torch.cat(source_embeddings, dim=0)
        target_embeddings = torch.cat(target_embeddings, dim=0)
        
        # Normalize embeddings
        source_embeddings = nn.functional.normalize(source_embeddings, dim=-1)
        target_embeddings = nn.functional.normalize(target_embeddings, dim=-1)
        
        # Compute similarity matrix
        similarity = torch.matmul(
            source_embeddings,
            target_embeddings.transpose(0, 1)
        ) / temperature
        
        return similarity

    def train_step(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        optimizer: torch.optim.Optimizer,
        temperature: float = 0.07
    ) -> float:
        """Perform one training step."""
        optimizer.zero_grad()
        
        # Forward pass
        similarity = self(source_contexts, target_pages, temperature)
        
        # The diagonal elements should be the positive pairs
        labels = torch.arange(len(source_contexts)).to(self.device)
        # print('#'*20)
        # print('similarity = ', similarity, ', labels = ', labels)
        
        # Compute loss
        loss = nn.CrossEntropyLoss()(similarity, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        return loss.item()

def main():
    # Example data
    source_text = """A '''mall''' or '''shopping center''' is a large [[building]] that is full of many smaller [[shop]]s."""
    
    target_pages = {
        "building": """A building is a human-made structure with a roof and walls. 
        Buildings come in many sizes and shapes and have been adapted to a wide range of functions.
        They can be used for housing, commercial, educational, or industrial activities.""",
        
        "shop": """A shop, also known as a store or retail establishment, is a business premises 
        that sells goods directly to customers. Shops can be independent small businesses or 
        part of larger retail chains."""
    }
    
    # Initialize model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoregressiveCitationMatcher(device=device)
    
    # Prepare inputs for first citation
    source_context = model.prepare_source_context(source_text, "building")
    target_content = model.prepare_target_page(target_pages["building"])
    
    print("Source context with [CITE] and [REF]:")
    print(source_context)
    print("\nTarget page with [REF]:")
    print(target_content)
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    # Example training step
    print("\nPerforming example training step...")
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        loss = model.train_step(
            [source_context],
            [target_content],
            optimizer
        )
    print(f"Training loss: {loss:.4f}")
    
    # Compute similarity
    print("\nComputing similarity...")
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            similarity = model([source_context], [target_content])
    print(f"Similarity score: {similarity[0][0].item():.4f}")

if __name__ == "__main__":
    main()

Using device: cuda
Source context with [CITE] and [REF]:
A '''mall''' or '''shopping center''' is a large <CITE> that is full of many smaller [[shop]]s. <REF>

Target page with [REF]:
A building is a human-made structure with a roof and walls. 
        Buildings come in many sizes and shapes and have been adapted to a wide range of functions.
        They can be used for housing, c... <REF>

Performing example training step...
####################
similarity =  tensor([[14.2422]], device='cuda:0', dtype=torch.float16,
       grad_fn=<DivBackward0>) , labels =  tensor([0], device='cuda:0')
Training loss: 0.0000

Computing similarity...
Similarity score: 14.2422


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


In [18]:
hidden_states = torch.randn(2, 10, 768)  # [batch_size, seq_len, hidden_size]
input_ids = torch.tensor([
    [1, 2, 3, 100, 5, 100, 7, 8, 9, 10],  # batch 1, token_id 100 appears twice
    [1, 2, 3, 4, 5, 6, 7, 100, 9, 10]     # batch 2, token_id 100 appears once
])
token_id = 100
token_positions = (input_ids == token_id).nonzero()
print('token_positions = ', token_positions)

token_embeddings = hidden_states[
    token_positions[:, 0],  # batch indices
    token_positions[:, 1]   # sequence positions
]

token_embeddings

token_positions =  tensor([[0, 3],
        [0, 5],
        [1, 7]])


tensor([[-1.3096,  0.5539,  0.1889,  ...,  1.3846,  0.5829,  0.5157],
        [-0.0750, -0.3833, -1.4829,  ..., -0.4143,  1.0510, -0.6311],
        [-0.5740,  0.5425,  0.4611,  ...,  0.3298,  0.3939, -0.5214]])

In [3]:
import torch
import torch.nn as nn
import random
import numpy as np
from typing import List, Dict, Tuple
import re

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import re
from typing import List, Dict, Tuple

class AutoregressiveCitationMatcher(nn.Module):
    def __init__(
        self,
        model_name: str = "gpt2",
        max_length: int = 512,
        cite_token: str = "<CITE>",
        ref_token: str = "<REF>",
        device: torch.device = None
    ):
        super().__init__()
        
        # Set device
        self.device = device if device is not None else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Add special tokens for citation and reference
        special_tokens = {
            'additional_special_tokens': [cite_token, ref_token]
        }
        self.tokenizer.add_special_tokens(special_tokens)
        
        # For GPT models, we might need to set pad token if it's not set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Initialize encoder for both source and target texts
        # We use the same model for both since autoregressive models maintain context
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        
        # Resize token embeddings
        self.model.resize_token_embeddings(len(self.tokenizer))
        
        self.cite_token = cite_token
        self.ref_token = ref_token
        self.max_length = max_length
        
        hidden_size = self.model.config.hidden_size
        # We'll use these to project the [REF] embeddings from both source and target
        self.source_projector = nn.Linear(hidden_size, hidden_size).to(self.device)
        self.target_projector = nn.Linear(hidden_size, hidden_size).to(self.device)
        
        print(f"Using device: {self.device}")

    def prepare_source_context(
        self,
        text: str,
        target_citation: str,
        window_size: int = 100
    ) -> str:
        """Prepare source context with citation and reference tokens."""
        citation_pattern = re.escape(f"[[{target_citation}]]")
        # Add both [CITE] and [REF] tokens
        modified_text = re.sub(citation_pattern, f"{self.cite_token}", text)
        modified_text = f"{modified_text} {self.ref_token}"
        
        cite_pos = modified_text.find(self.cite_token)
        if cite_pos == -1:
            raise ValueError("Citation token not found in text")
            
        # Include full context up to the [REF] token
        start = max(0, cite_pos - window_size)
        return modified_text[start:]

    def prepare_target_page(
        self,
        page_content: str,
        max_summary_length: int = 200
    ) -> str:
        """Prepare target page with reference token."""
        # Take first paragraph as summary
        first_para = page_content.split('\n\n')[0].strip()
        if len(first_para) > max_summary_length:
            first_para = first_para[:max_summary_length] + "..."
            
        return f"{first_para} {self.ref_token}"

    def get_token_embedding(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        token_id: int
    ) -> torch.Tensor:
        """Extract embedding for specific token."""
        # Get token positions (batch_size, num_occurrences)
        token_positions = (input_ids == token_id).nonzero()
        
        if len(token_positions) == 0:
            raise ValueError(f"Token id {token_id} not found in sequence")
        
        # Extract embeddings at token positions
        token_embeddings = hidden_states[
            token_positions[:, 0],
            token_positions[:, 1]
        ]
        
        # If multiple occurrences, raise an error
        if len(token_positions) > 1:
            raise ValueError(f"There should only be one occurance of REF token")
            
        return token_embeddings

    def encode_text(
        self,
        text: str,
        is_source: bool = True
    ) -> torch.Tensor:
        """Encode text and extract reference token embedding."""
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).to(self.device)
        
        # Get hidden states from the model
        outputs = self.model(
            **inputs,
            output_hidden_states=True,
            return_dict=True
        )
        
        # Get [REF] token embedding from the last hidden state
        ref_token_id = self.tokenizer.convert_tokens_to_ids(self.ref_token)
        ref_embeddings = self.get_token_embedding(
            outputs.hidden_states[-1],
            inputs['input_ids'],
            ref_token_id
        )
        return ref_embeddings
        
        # Project the embeddings based on whether it's source or target
    
        # if is_source:
        #     return self.source_projector(ref_embeddings)
        # else:
        #     return self.target_projector(ref_embeddings)

    def forward(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        temperature: float = 0.07
    ) -> torch.Tensor:
        """Compute similarity between source and target [REF] embeddings."""
        # Encode all contexts and references
        source_embeddings = []
        target_embeddings = []
        
        for context in source_contexts:
            source_emb = self.encode_text(context, is_source=True)
            source_embeddings.append(source_emb)
            
        for page in target_pages:
            target_emb = self.encode_text(page, is_source=False)
            target_embeddings.append(target_emb)
        
        # Stack embeddings
        source_embeddings = torch.cat(source_embeddings, dim=0)
        target_embeddings = torch.cat(target_embeddings, dim=0)
        
        # Normalize embeddings
        source_embeddings = nn.functional.normalize(source_embeddings, dim=-1)
        target_embeddings = nn.functional.normalize(target_embeddings, dim=-1)
        
        # Compute similarity matrix
        similarity = torch.matmul(
            source_embeddings,
            target_embeddings.transpose(0, 1)
        ) / temperature
        
        return similarity

    def train_step(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        optimizer: torch.optim.Optimizer,
        temperature: float = 0.07
    ) -> float:
        """Perform one training step."""
        optimizer.zero_grad()
        
        # Forward pass
        similarity = self(source_contexts, target_pages, temperature)
        
        # The diagonal elements should be the positive pairs
        labels = torch.arange(len(source_contexts)).to(self.device)
        # print('#'*20)
        # print('similarity = ', similarity, ', labels = ', labels)
        
        # Compute loss
        loss = nn.CrossEntropyLoss()(similarity, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        return loss.item()

# Dictionary of source pages (articles) with their content
source_pages = {
    "Machine_Learning": """'''Machine learning''' is a subfield of [[artificial intelligence]] that focuses on developing systems that can learn from [[data]].
    Traditional machine learning approaches include [[supervised learning]], [[unsupervised learning]], and [[reinforcement learning]].
    Modern applications often use [[deep learning]] architectures trained on large datasets.
    The field has seen significant advances since the development of [[neural network|neural networks]].""",
    
    "Climate_Change": """'''Climate change''' refers to long-term alterations in [[Earth]]'s climate systems.
    Major factors include [[greenhouse gas|greenhouse gases]] and [[carbon dioxide]] emissions.
    Scientists study [[glacier|glaciers]] and [[ice sheet|ice sheets]] to understand historical climate patterns.
    [[Renewable energy]] technologies are crucial for mitigating climate change impacts.
    Changes in [[weather pattern|weather patterns]] and [[sea level|sea levels]] are key indicators.""",
    
    "Internet": """The '''Internet''' is a global network that revolutionized [[communication]].
    It operates using protocols like [[TCP/IP]] and enables various forms of [[e-commerce]].
    [[World Wide Web|Web]] technologies have evolved from simple [[HTML]] to complex [[web application|web applications]].
    Modern internet infrastructure relies heavily on [[cloud computing]] and [[data center|data centers]].""",
    
    "Solar_System": """The '''Solar System''' consists of the [[Sun]] and celestial bodies bound by its gravity.
    It includes eight [[planet|planets]], numerous [[moon|moons]], and countless [[asteroid|asteroids]].
    The [[asteroid belt]] lies between [[Mars]] and [[Jupiter]].
    [[Comet|Comets]] and [[meteor|meteors]] are other notable objects in our cosmic neighborhood.""",
    
    "Human_Brain": """The '''human brain''' is the central [[organ]] of the [[nervous system]].
    It consists of regions like the [[cerebral cortex]] and [[hippocampus]].
    [[Neuron|Neurons]] communicate through [[synapse|synapses]] using [[neurotransmitter|neurotransmitters]].
    Modern [[neuroimaging]] techniques like [[MRI]] help study brain structure and function.""",
    
    "Artificial_Intelligence": """'''Artificial intelligence''' encompasses various approaches to creating intelligent systems.
    Key subfields include [[machine learning]] and [[natural language processing]].
    [[Computer vision]] systems can now perform complex visual tasks.
    [[Expert system|Expert systems]] and [[robotics]] demonstrate practical AI applications.
    Recent advances in [[deep learning]] have revolutionized the field.""",
}

# Dictionary of target pages
target_pages = {
    "artificial intelligence": """Artificial intelligence (AI) is intelligence demonstrated by machines, 
    as opposed to natural intelligence displayed by animals including humans. AI systems can perform 
    tasks that typically require human intelligence.""",
    
    "data": """Data are individual facts, statistics, or items of information, often numeric. In 
    computing, data represents information that can be processed, stored, or transmitted by computers.""",
    
    "supervised learning": """Supervised learning is a machine learning approach where the model learns 
    from labeled training data to make predictions on new, unseen data. It's widely used in 
    classification and regression tasks.""",
    
    "Earth": """Earth is the third planet from the Sun and the only astronomical object known to 
    harbor life. It's atmosphere and magnetic field protect life from harmful solar radiation.""",
    
    "carbon dioxide": """Carbon dioxide (CO2) is a greenhouse gas that plays a vital role in Earth's 
    carbon cycle. Increased atmospheric CO2 from human activities is a major driver of climate change.""",
    
    "TCP/IP": """TCP/IP (Transmission Control Protocol/Internet Protocol) is the foundational 
    communication protocol of the Internet. It specifies how data should be packetized, addressed, 
    transmitted, routed, and received.""",
    
    "Sun": """The Sun is the star at the center of the Solar System. It is a nearly perfect sphere 
    of hot plasma, heated to incandescence by nuclear fusion reactions in its core.""",
    
    "Mars": """Mars is the fourth planet from the Sun. Often called the Red Planet, it has a thin 
    atmosphere and features like valleys, deserts, and polar ice caps. It's a major target for 
    space exploration.""",
    
    "cerebral cortex": """The cerebral cortex is the outer layer of neural tissue of the cerebrum 
    of the brain. It plays a key role in memory, attention, perception, awareness, thought, 
    language, and consciousness.""",
    
    "machine learning": """Machine learning is a field of artificial intelligence that uses 
    statistical techniques to give computer systems the ability to 'learn' from data, without 
    being explicitly programmed."""
    # ... add more target pages as needed
}

def extract_citations(text: str) -> List[str]:
    """Extract all citations from a text and clean them."""
    citations = re.findall(r'\[\[(.*?)\]\]', text)
    # Clean up citations (take first part if pipe character exists)
    return [c.split('|')[0].lower() for c in citations]

def create_training_batch(
    model,
    batch_size: int,
    source_pages: Dict[str, str],
    target_pages: Dict[str, str]
) -> Tuple[List[str], List[str]]:
    """
    Create a training batch by randomly sampling source pages and their citations.
    
    Args:
        model: The citation matcher model
        batch_size: Number of examples in the batch
        source_pages: Dictionary of source page titles and their content
        target_pages: Dictionary of target page titles and their content
    
    Returns:
        source_contexts: List of prepared source contexts
        target_contents: List of prepared target contents
    """
    source_contexts = []
    target_contents = []
    
    # Get random permutation of source page keys
    source_keys = list(source_pages.keys())
    random.shuffle(source_keys)
    
    # Try to fill the batch
    attempts = 0
    max_attempts = batch_size * 3  # Allow some extra attempts to fill the batch
    
    while len(source_contexts) < batch_size and attempts < max_attempts:
        # Get next source page
        source_key = source_keys[attempts % len(source_keys)]
        source_text = source_pages[source_key]
        
        # Get all citations from this source page
        citations = extract_citations(source_text)
        
        # Filter to citations that have matching target pages
        valid_citations = [c for c in citations if c in target_pages]
        
        if valid_citations:
            # Randomly select one citation
            citation = random.choice(valid_citations)
            
            try:
                # Prepare source context and target content
                source_context = model.prepare_source_context(source_text, citation)
                target_content = model.prepare_target_page(target_pages[citation])
                
                source_contexts.append(source_context)
                target_contents.append(target_content)
                
            except ValueError as e:
                print(f"Error processing citation {citation} from {source_key}: {str(e)}")
        
        attempts += 1
    
    if len(source_contexts) < batch_size:
        print(f"Warning: Could only create {len(source_contexts)} examples for requested batch size {batch_size}")
    
    return source_contexts, target_contents

def main():
    # Initialize model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoregressiveCitationMatcher(device=device)
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    # Training parameters
    num_epochs = 3
    batch_size = 4
    steps_per_epoch = 10
    
    print("Starting training loop...")
    
    for epoch in range(num_epochs):
        total_loss = 0
        
        for step in range(steps_per_epoch):
            # Create training batch
            source_contexts, target_contents = create_training_batch(
                model,
                batch_size,
                source_pages,
                target_pages
            )
            
            if not source_contexts or not target_contents:
                print("Warning: Empty batch, skipping step")
                continue
            
            # Perform training step
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                loss = model.train_step(
                    source_contexts,
                    target_contents,
                    optimizer
                )
            
            total_loss += loss
            print(f"Epoch {epoch + 1}, Step {step + 1}/{steps_per_epoch}, Loss: {loss:.4f}")
        
        avg_loss = total_loss / steps_per_epoch
        print(f"Epoch {epoch + 1} complete. Average loss: {avg_loss:.4f}")

    # Example of model usage after training
    print("\nTesting model with a sample prediction...")
    source_text = source_pages["Machine_Learning"]
    citation = "artificial intelligence"
    
    source_context = model.prepare_source_context(source_text, citation)
    target_content = model.prepare_target_page(target_pages[citation])
    
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            similarity = model([source_context], [target_content])
    print(f"Similarity score: {similarity[0][0].item():.4f}")

if __name__ == "__main__":
    main()



Using device: cuda
Starting training loop...
Epoch 1, Step 1/10, Loss: 1.3787


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Epoch 1, Step 2/10, Loss: 1.3484
Epoch 1, Step 3/10, Loss: 1.2130
Epoch 1, Step 4/10, Loss: 0.8815
Epoch 1, Step 5/10, Loss: 0.8069
Epoch 1, Step 6/10, Loss: 0.7612
Epoch 1, Step 7/10, Loss: 0.6135
Epoch 1, Step 8/10, Loss: 8.0613
Epoch 1, Step 9/10, Loss: 4.4313
Epoch 1, Step 10/10, Loss: 4.6977
Epoch 1 complete. Average loss: 2.4193
Epoch 2, Step 1/10, Loss: 3.4214
Epoch 2, Step 2/10, Loss: 3.4023
Epoch 2, Step 3/10, Loss: 1.8708
Epoch 2, Step 4/10, Loss: 0.7561
Epoch 2, Step 5/10, Loss: 0.7629
Epoch 2, Step 6/10, Loss: 0.7691
Epoch 2, Step 7/10, Loss: 0.7720
Epoch 2, Step 8/10, Loss: 0.7801
Epoch 2, Step 9/10, Loss: 0.7840
Epoch 2, Step 10/10, Loss: 0.7914
Epoch 2 complete. Average loss: 1.4110
Epoch 3, Step 1/10, Loss: 0.8025
Epoch 3, Step 2/10, Loss: 0.8111
Epoch 3, Step 3/10, Loss: 0.8156
Epoch 3, Step 4/10, Loss: 0.8202
Epoch 3, Step 5/10, Loss: 0.8271
Epoch 3, Step 6/10, Loss: 0.8326
Epoch 3, Step 7/10, Loss: 0.8344
Epoch 3, Step 8/10, Loss: 0.8377
Epoch 3, Step 9/10, Loss: 0.8

  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
