In [None]:
import torch
import torch.nn as nn
import random
import numpy as np
from typing import List, Dict, Tuple
import re
from transformers import AutoTokenizer, AutoModel,AutoModelForCausalLM

import xml.etree.ElementTree as ET
import bz2
import json
import sqlite3
from pathlib import Path
from typing import Iterator, Dict, Union, Optional
import re
import os 
import tqdm

from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict

class WikiDumpParser:
    def __init__(self, dump_path: str):
        self.dump_path = dump_path
        self.ns = {'mw': 'http://www.mediawiki.org/xml/export-0.10/'}

    def _open_dump(self) -> Iterator:
        if self.dump_path.endswith('.bz2'):
            return bz2.BZ2File(self.dump_path)
        return open(self.dump_path, 'rb')

    def iter_pages(self) -> Iterator[Dict]:
        context = ET.iterparse(self._open_dump(), events=('end',))
        
        for event, elem in context:
            if elem.tag.endswith('page'):
                title_elem = elem.find('.//mw:title', self.ns)
                revision = elem.find('.//mw:revision', self.ns)
                
                if revision is not None:
                    text_elem = revision.find('mw:text', self.ns)
                    timestamp_elem = revision.find('mw:timestamp', self.ns)
                    text = text_elem.text if text_elem is not None else ''
                    timestamp = timestamp_elem.text if timestamp_elem is not None else ''
                else:
                    text = ''
                    timestamp = ''
                
                title = title_elem.text if title_elem is not None else ''
                
                yield {
                    'title': title,
                    'text': text,
                    'timestamp': timestamp,
                    'is_redirect': bool(re.match(r'#REDIRECT', text or '', re.IGNORECASE))
                }
                
                elem.clear()

    def iter_articles(self, skip_redirects: bool = True) -> Iterator[Dict]:
        for page in self.iter_pages():
            title = page['title']
            
            if any(title.startswith(prefix) for prefix in [
                'Wikipedia:', 'Template:', 'Category:', 'Portal:',
                'File:', 'MediaWiki:', 'Help:', 'Book:', 'Draft:',
                'TimedText:', 'Module:', 'Special:'
            ]):
                continue
                
            if skip_redirects and page['is_redirect']:
                continue
                
            yield page

    def save_to_jsonl(self, output_path: Union[str, Path], sample_size: Optional[int] = None) -> int:
        """
        Save articles to a JSONL file (one JSON object per line).
        
        Args:
            output_path: Path to the output JSONL file
            sample_size: Optional number of articles to save
            
        Returns:
            int: Number of articles saved
        """
        count = 0
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, article in enumerate(self.iter_articles()):
                if sample_size is not None and i >= sample_size:
                    break
                json.dump(article, f, ensure_ascii=False)
                f.write('\n')
                count += 1
        return count

    def save_to_sqlite(self, db_path: Union[str, Path], sample_size: Optional[int] = None,
                      batch_size: int = 1000) -> int:
        """
        Save articles to a SQLite database.
        
        Args:
            db_path: Path to the SQLite database file
            sample_size: Optional number of articles to save
            batch_size: Number of articles to insert in each batch
            
        Returns:
            int: Number of articles saved
        """
        conn = sqlite3.connect(db_path)
        c = conn.cursor()
        
        # Create table if it doesn't exist
        c.execute('''CREATE TABLE IF NOT EXISTS articles
                    (title TEXT PRIMARY KEY,
                     text TEXT,
                     timestamp TEXT,
                     is_redirect INTEGER)''')
        
        # Create index on title
        c.execute('CREATE INDEX IF NOT EXISTS idx_title ON articles(title)')
        
        count = 0
        batch = []
        
        try:
            for i, article in enumerate(self.iter_articles()):
                if sample_size is not None and i >= sample_size:
                    break
                    
                batch.append((
                    article['title'],
                    article['text'],
                    article['timestamp'],
                    1 if article['is_redirect'] else 0
                ))
                
                if len(batch) >= batch_size:
                    c.executemany(
                        'INSERT OR REPLACE INTO articles VALUES (?, ?, ?, ?)',
                        batch
                    )
                    conn.commit()
                    count += len(batch)
                    batch = []
            
            # Insert remaining articles
            if batch:
                c.executemany(
                    'INSERT OR REPLACE INTO articles VALUES (?, ?, ?, ?)',
                    batch
                )
                conn.commit()
                count += len(batch)
                
        finally:
            conn.close()
            
        return count

def read_from_jsonl(file_path: Union[str, Path], limit: Optional[int] = None) -> Iterator[Dict]:
    """
    Read articles from a JSONL file.
    
    Args:
        file_path: Path to the JSONL file
        limit: Optional maximum number of articles to read
        
    Yields:
        Dict: Article information
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            yield json.loads(line)

def read_from_sqlite(db_path: Union[str, Path], limit: Optional[int] = None) -> Iterator[Dict]:
    """
    Read articles from a SQLite database.
    
    Args:
        db_path: Path to the SQLite database
        limit: Optional maximum number of articles to read
        
    Yields:
        Dict: Article information
    """
    conn = sqlite3.connect(db_path)
    c = conn.cursor()
    
    try:
        query = 'SELECT title, text, timestamp, is_redirect FROM articles'
        if limit is not None:
            query += f' LIMIT {limit}'
            
        for row in c.execute(query):
            yield {
                'title': row[0],
                'text': row[1],
                'timestamp': row[2],
                'is_redirect': bool(row[3])
            }
    finally:
        conn.close()


def extract_citations(text: str) -> List[str]:
    """Extract all citations from a text and clean them."""
    citations = re.findall(r'\[\[(.*?)\]\]', text)
    # Clean up citations (take first part if pipe character exists)
    return [c.split('|')[0]for c in citations]

def clean_wiki_content(text: str) -> str:
    """
    Clean wiki content by:
    1. Finding the start of the actual article (from triple quotes)
    2. Removing category tags, file links, and other metadata
    
    Args:
        text: Raw wiki text
    
    Returns:
        str: Cleaned article content starting from title
    """
    # Find the main article content starting with triple quotes
    title_match = re.search(r"'''([^']+?)'''", text)
    if not title_match:
        return text  # Return original if no title found
        
    # Get the position where the title starts
    start_pos = title_match.start()
    content = text[start_pos:]
    
    # Remove category tags
    content = re.sub(r'\[\[Category:.*?\]\]', '', content)
    
    # Remove file/image links
    content = re.sub(r'\[\[File:.*?\]\]', '', content)
    
    # Remove stub templates
    content = re.sub(r'\{\{stub\}\}', '', content)
    
    # Remove empty lines that might remain
    content = '\n'.join(line for line in content.split('\n') if line.strip())
    
    return content.strip()

def create_article_dict(jsonl_path, sample_size: Optional[int] = None) -> Dict[str, str]:
    """
    Create a dictionary of articles from the wiki parser, with cleaned content.
    
    Args:
        jsonl: the path to the saved wiki pages as [Title] -> [Text] key value format. 
        sample_size: Optional number of articles to include
        
    Returns:
        Dict[str, str]: Dictionary with article titles as keys and cleaned content as values
    """
    articles = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            article = json.loads(line)
            if sample_size is not None and i >= sample_size:
                break
                
            # Clean the content to start from title
            cleaned_content = clean_wiki_content(article['text'])
            
            if cleaned_content:  # Only include if content remains after cleaning
                articles[article['title'].lower()] = cleaned_content
            
    return articles

def prepare_sources_targets(articles_dict, sample_size=1000, cite_sample=1):
    articles = np.random.permutation(list(articles_dict.keys()))[:sample_size]
    # Initialize your model (assuming it's defined elsewhere)
    model = AutoregressiveCitationMatcher()
    sources = []
    targets = []
    for article in tqdm.tqdm(articles):
        source_text = articles_dict[article]
        # Get all citations from this source article
        citations = extract_citations(source_text)
        valid_citations = [c for c in citations if c.lower() in articles_dict]
    
        for citation in np.random.choice(valid_citations,min(cite_sample,len(valid_citations))):
            try:
                source_content = model.prepare_source_context(source_text, citation)
                target_content = model.prepare_target_page(articles_dict[citation.lower()])
                sources.append(source_content)
                targets.append(target_content)
            except Exception:
                pass
    return sources, targets 


class AutoregressiveCitationMatcher(nn.Module):
    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        max_length: int = 512,
        cite_token: str = "<CITE>",
        ref_token: str = "<REF>",
        device: torch.device = None
    ):
        super().__init__()
        
        # Set device
        self.device = device if device is not None else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Add special tokens for citation and reference
        special_tokens = {
            'additional_special_tokens': [cite_token, ref_token]
        }
        self.tokenizer.add_special_tokens(special_tokens)
        
        # For GPT models, we might need to set pad token if it's not set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Initialize encoder for both source and target texts
        # We use the same model for both since autoregressive models maintain context
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        
        # Resize token embeddings
        self.model.resize_token_embeddings(len(self.tokenizer))
        
        self.cite_token = cite_token
        self.ref_token = ref_token
        self.max_length = max_length
        
        hidden_size = self.model.config.hidden_size
        
        print(f"Using device: {self.device}")

    def prepare_source_context(
        self,
        text: str,
        target_citation: str,
        window_size: int = 100
    ) -> str:
        """Prepare source context with citation and reference tokens."""
        citation_pattern = re.escape(f"[[{target_citation}]]")
        # Add both <CITE> and <REF> tokens
        modified_text = re.sub(citation_pattern, f"{self.cite_token}", text)
        # modified_text = f"{modified_text} {self.ref_token}"

        if self.cite_token not in modified_text:
            raise ValueError(f"There is no cite token")
        
        return modified_text

    def prepare_target_page(
        self,
        page_content: str,
    ) -> str:
        """Prepare target page with reference token."""
        # Take first paragraph as summary and add <REF> token
            
        return f"{page_content} {self.ref_token}"

    def get_token_embedding(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        token_id: int
    ) -> torch.Tensor:
        """Extract embedding for specific token for each item in the batch."""
        batch_size = input_ids.size(0)
        embeddings = []
        
        for batch_idx in range(batch_size):
            # Find token positions for this batch item
            token_positions = (input_ids[batch_idx] == token_id).nonzero()
            
            if len(token_positions) == 0:
                raise ValueError(f"Token id {token_id} not found in sequence for batch item {batch_idx}")
                    
            if len(token_positions) > 1:
                # If multiple occurrences, take the last one
                position = token_positions[-1].item()
            else:
                position = token_positions[0].item()
                    
            # Extract embedding for this batch item
            embedding = hidden_states[batch_idx, position, :]  # Explicitly select all hidden dimensions
            embeddings.append(embedding)
            
        # Stack all embeddings into a single tensor
        return torch.stack(embeddings)  # Shape will be [batch_size, hidden_size]

    def encode_text(
        self,
        text: str,
        is_source: bool = True
    ) -> torch.Tensor:
        """Encode text and extract reference token embedding."""
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).to(self.device)
        
        # Get hidden states from the model
        outputs = self.model(
            **inputs,
            output_hidden_states=True,
            return_dict=True
        )
        
        # Get [REF] token embedding from the last hidden state
        ref_token_id = self.tokenizer.convert_tokens_to_ids(self.ref_token)
        ref_embeddings = self.get_token_embedding(
            outputs.hidden_states[-1],
            inputs['input_ids'],
            ref_token_id
        )
        return ref_embeddings
        

    def forward(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        temperature: float = 0.07
    ) -> torch.Tensor:
        """Compute similarity between source and target [REF] embeddings."""
        # Encode all contexts and references
        source_embeddings = []
        target_embeddings = []
        
        for context in source_contexts:
            source_emb = self.encode_text(context, is_source=True)
            source_embeddings.append(source_emb)
            
        for page in target_pages:
            target_emb = self.encode_text(page, is_source=False)
            target_embeddings.append(target_emb)
        
        # Stack embeddings
        source_embeddings = torch.cat(source_embeddings, dim=0)
        target_embeddings = torch.cat(target_embeddings, dim=0)
        
        # Normalize embeddings
        source_embeddings = nn.functional.normalize(source_embeddings, dim=-1)
        target_embeddings = nn.functional.normalize(target_embeddings, dim=-1)
        
        # Compute similarity matrix
        similarity = torch.matmul(
            source_embeddings,
            target_embeddings.transpose(0, 1)
        ) / temperature
        
        return similarity

    def train_step(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        optimizer: torch.optim.Optimizer,
        temperature: float = 0.07
    ) -> float:
        """Perform one training step."""
        optimizer.zero_grad()
        
        # Forward pass
        similarity = self(source_contexts, target_pages, temperature)
        
        # The diagonal elements should be the positive pairs
        labels = torch.arange(len(source_contexts)).to(self.device)
        # print('#'*20)
        # print('similarity = ', similarity, ', labels = ', labels)
        
        # Compute loss
        loss = nn.CrossEntropyLoss()(similarity, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        return loss.item()



class CitationDataset(Dataset):
    def __init__(
        self,
        sources: List[str],
        targets: List[str],
        tokenizer,
        max_length: int = 512,
        verbose: bool = True
    ):
        """
        Initialize dataset with preprocessing of all samples.
        
        Args:
            sources: List of source texts
            targets: List of target texts
            tokenizer: Tokenizer to use
            max_length: Maximum sequence length
            verbose: Whether to print processing statistics
        """
        assert len(sources) == len(targets), "Sources and targets must have same length"
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.ref_token = "<REF>"
        self.cite_token = "<CITE>"
        
        # Get token IDs
        self.ref_token_id = self.tokenizer.convert_tokens_to_ids(self.ref_token)
        self.cite_token_id = self.tokenizer.convert_tokens_to_ids(self.cite_token)
        
        if verbose:
            print("Pre-processing dataset...")
            print(f"Initial samples: {len(sources)}")
        
        # Pre-process and store all valid samples
        self.processed_samples = self._preprocess_samples(sources, targets, verbose)
        
        if verbose:
            print(f"Final processed samples: {len(self.processed_samples)}")
            print("Dataset preprocessing completed.")

    def _preprocess_samples(
        self,
        sources: List[str],
        targets: List[str],
        verbose: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Pre-process all samples and filter invalid ones.
        Returns list of valid processed samples.
        """
        processed_samples = []
        skipped_no_cite = 0
        skipped_errors = 0
        
        for idx, (source, target) in enumerate(zip(sources, targets)):
            try:
                # Tokenize source and target
                source_tokens = self.tokenizer(
                    source,
                    padding='max_length',
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors='pt'
                )
                
                target_tokens = self.tokenizer(
                    target,
                    padding='max_length',
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors='pt'
                )
                
                
                # Ensure REF token in both source and target
                # source_tokens['input_ids'] = self.ensure_ref_token(source_tokens['input_ids'])
                target_tokens['input_ids'] = self.ensure_ref_token(target_tokens['input_ids'])
                
                # Store processed tensors
                processed_sample = {
                    'source_input_ids': source_tokens['input_ids'].squeeze(0),
                    'source_attention_mask': source_tokens['attention_mask'].squeeze(0),
                    'target_input_ids': target_tokens['input_ids'].squeeze(0),
                    'target_attention_mask': target_tokens['attention_mask'].squeeze(0)
                }
                
                # Check if CITE token exists in source after tokenization
                if self.cite_token_id not in source_tokens['input_ids'][0]:
                    skipped_no_cite += 1
                    continue
                
                processed_samples.append(processed_sample)
                
            except Exception as e:
                skipped_errors += 1
                if verbose:
                    print(f"Error processing sample {idx}: {str(e)}")
                continue
            
            # Print progress periodically
            if verbose and (idx + 1) % 10000 == 0:
                print(f"Processed {idx + 1} samples...")
        
        if verbose:
            print(f"Skipped {skipped_no_cite} samples due to missing CITE token")
            print(f"Skipped {skipped_errors} samples due to processing errors")
        
        return processed_samples

    def ensure_ref_token(self, tokens: torch.Tensor) -> torch.Tensor:
        """Ensure the REF token is present in the sequence, replacing last non-pad token if needed."""
        batch_size = tokens.size(0)
        
        # For each sequence in the batch
        for i in range(batch_size):
            sequence = tokens[i]
            # Find positions of REF token
            ref_positions = (sequence == self.ref_token_id).nonzero()
            
            if len(ref_positions) == 0:
                # If no REF token found, replace the last non-pad token
                pad_token_id = self.tokenizer.pad_token_id
                # Find the last non-pad token position
                non_pad_positions = (sequence != pad_token_id).nonzero()
                if len(non_pad_positions) > 0:
                    last_non_pad_pos = non_pad_positions[-1]
                    sequence[last_non_pad_pos] = self.ref_token_id
            elif len(ref_positions) > 1:
                # If multiple REF tokens, keep only the last one
                for pos in ref_positions[:-1]:
                    sequence[pos] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token)
        
        return tokens

    def __len__(self):
        return len(self.processed_samples)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return pre-processed sample."""
        return self.processed_samples[idx]

def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Custom collate function to handle batching of tokenized inputs."""
    return {
        key: torch.stack([item[key] for item in batch])
        for key in batch[0].keys()
    }

def train_model(
    model: AutoregressiveCitationMatcher,
    train_sources: List[str],
    train_targets: List[str],
    val_sources: List[str] = None,
    val_targets: List[str] = None,
    batch_size: int = 32,
    num_epochs: int = 10,
    learning_rate: float = 1e-4,
    temperature: float = 0.07,
    num_workers: int = 4,
    validation_steps: int = 100,
    save_path: str = 'checkpoint.pt'
):
    """
    Memory-optimized training function that only keeps the last hidden layer.
    """
    # Create datasets
    train_dataset = CitationDataset(
        train_sources,
        train_targets,
        model.tokenizer,
        model.max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )

    if val_sources and val_targets:
        val_dataset = CitationDataset(
            val_sources,
            val_targets,
            model.tokenizer,
            model.max_length
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=True
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criteria = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=2,
        verbose=True
    )

    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        
        progress_bar = tqdm.tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}')
        
        for batch in progress_bar:
            batch = {k: v.to(model.device) for k, v in batch.items()}
            
            # Forward pass - only get last hidden layer
            source_outputs = model.model(
                input_ids=batch['source_input_ids'],
                attention_mask=batch['source_attention_mask'],
                output_hidden_states=False,  # Only get last layer
                return_dict=True
            )
            
            target_outputs = model.model(
                input_ids=batch['target_input_ids'],
                attention_mask=batch['target_attention_mask'],
                output_hidden_states=False,  # Only get last layer
                return_dict=True
            )
            
            # Get token positions efficiently
            cite_token_id = model.tokenizer.convert_tokens_to_ids(model.cite_token)
            ref_token_id = model.tokenizer.convert_tokens_to_ids(model.ref_token)

            source_embeddings = model.get_token_embedding(
                source_outputs.last_hidden_state,
                batch['source_input_ids'],
                cite_token_id
            )
            
            target_embeddings = model.get_token_embedding(
                target_outputs.last_hidden_state,
                batch['target_input_ids'],
                ref_token_id
            )
            
            # Normalize embeddings
            source_embeddings = torch.nn.functional.normalize(source_embeddings, dim=-1)
            target_embeddings = torch.nn.functional.normalize(target_embeddings, dim=-1)
            
            # Compute similarity matrix
            similarity = torch.matmul(source_embeddings, target_embeddings.transpose(0, 1)) / temperature
            
            # Compute loss
            labels = torch.arange(similarity.size(0)).to(model.device)
            loss = criteria(similarity, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss': f'{total_loss / num_batches:.4f}'
            })
            
            # Free memory
            del source_outputs, target_outputs, source_embeddings, target_embeddings, similarity
            torch.cuda.empty_cache()  # Optional, can help with memory fragmentation
            
            
        # Validation
        if val_sources and val_targets:
            val_loss = validate_model(model, val_loader, temperature)
            scheduler.step(val_loss)
    
            print(f"Validation loss: {val_loss:0.4f}")
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), save_path)
                print(f"\nSaved new best model with validation loss: {val_loss:.4f}")
            
            model.train()

                
        epoch_loss = total_loss / num_batches
        print(f"\nEpoch {epoch + 1} average loss: {epoch_loss:.4f}")

def validate_model(
    model: AutoregressiveCitationMatcher,
    val_loader: DataLoader,
    temperature: float
) -> float:
    """
    Memory-optimized validation function that only keeps the last hidden layer.
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for batch in tqdm.tqdm(val_loader, desc='Validation'):
            batch = {k: v.to(model.device) for k, v in batch.items()}
            
            # Forward pass - only get last hidden layer
            source_outputs = model.model(
                input_ids=batch['source_input_ids'],
                attention_mask=batch['source_attention_mask'],
                output_hidden_states=False,
                return_dict=True
            )
            
            target_outputs = model.model(
                input_ids=batch['target_input_ids'],
                attention_mask=batch['target_attention_mask'],
                output_hidden_states=False,
                return_dict=True
            )
            
            try:
                # Get token positions efficiently
                cite_token_id = model.tokenizer.convert_tokens_to_ids(model.cite_token)
                ref_token_id = model.tokenizer.convert_tokens_to_ids(model.ref_token)
                
                source_embeddings = model.get_token_embedding(
                    source_outputs.last_hidden_state,
                    batch['source_input_ids'],
                    cite_token_id
                )
                
                target_embeddings = model.get_token_embedding(
                    target_outputs.last_hidden_state,
                    batch['target_input_ids'],
                    ref_token_id
                )
                
                # Normalize embeddings
                source_embeddings = torch.nn.functional.normalize(source_embeddings, dim=-1)
                target_embeddings = torch.nn.functional.normalize(target_embeddings, dim=-1)
                
                # Compute similarity and loss
                similarity = torch.matmul(source_embeddings, target_embeddings.transpose(0, 1)) / temperature
                
                labels = torch.arange(similarity.size(0)).to(model.device)
                loss = torch.nn.CrossEntropyLoss()(similarity, labels)
                
                total_loss += loss.item()
                num_batches += 1
                
                # Free memory
                del source_outputs, target_outputs, source_embeddings, target_embeddings, similarity
                
            except ValueError as e:
                continue
    
    return total_loss / num_batches if num_batches > 0 else float('inf')

# Usage example:
# Split data into train and validation sets
def split_data(sources: List[str], targets: List[str], val_ratio: float = 0.1) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Split data into training and validation sets."""
    val_size = int(len(sources) * val_ratio)
    indices = torch.randperm(len(sources))
    
    train_indices = indices[val_size:]
    val_indices = indices[:val_size]
    
    train_sources = [sources[i] for i in train_indices]
    train_targets = [targets[i] for i in train_indices]
    val_sources = [sources[i] for i in val_indices]
    val_targets = [targets[i] for i in val_indices]
    
    return train_sources, train_targets, val_sources, val_targets



# wiki_parser = WikiDumpParser('./data/wiki/simplewiki-latest-pages-articles.xml')
# wiki_parser.save_to_jsonl('./wiki_articles.jsonl')

articles_dict = create_article_dict(jsonl_path='./wiki_articles.jsonl')
sources, targets = prepare_sources_targets(articles_dict, sample_size = 2000)

# Split the data
train_sources, train_targets, val_sources, val_targets = split_data(sources, targets)

# Initialize model (assuming it's already defined)
model = AutoregressiveCitationMatcher()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Train the model
train_model(
    model=model,
    train_sources=train_sources,
    train_targets=train_targets,
    val_sources=val_sources,
    val_targets=val_targets,
    batch_size=16,
    num_epochs=10,
    learning_rate=1.5e-4,
    temperature=0.1,
    num_workers=4,
    validation_steps=50,
    save_path='best_model.pt'
)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


100%|████████████████████████████████████| 2000/2000 [00:00<00:00, 11186.99it/s]


Using device: cuda
Pre-processing dataset...
Initial samples: 1342


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from typing import List, Tuple, Dict
import numpy as np
import gc

class CitationDataset(Dataset):
    def __init__(
        self,
        sources: List[str],
        targets: List[str],
        tokenizer,
        max_length: int = 512
    ):
        assert len(sources) == len(targets), "Sources and targets must have same length"
        self.sources = sources
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        source = self.sources[idx]
        target = self.targets[idx]

        # Tokenize source and target texts
        source_tokens = self.tokenizer(
            source,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        target_tokens = self.tokenizer(
            target,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Remove the batch dimension added by tokenizer
        return {
            'source_input_ids': source_tokens['input_ids'].squeeze(0),
            'source_attention_mask': source_tokens['attention_mask'].squeeze(0),
            'target_input_ids': target_tokens['input_ids'].squeeze(0),
            'target_attention_mask': target_tokens['attention_mask'].squeeze(0)
        }

def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Custom collate function to handle batching of tokenized inputs."""
    return {
        key: torch.stack([item[key] for item in batch])
        for key in batch[0].keys()
    }

def print_gpu_memory_status():
    """Print current GPU memory usage."""
    if torch.cuda.is_available():
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
        print(f"Memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

def clean_memory():
    """Clean up GPU memory."""
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

def validate_model(
    model: AutoregressiveCitationMatcher,
    val_loader: DataLoader,
    temperature: float,
    scaler: GradScaler
) -> float:
    """
    Validate the model on validation set.
    
    Returns:
        float: Average validation loss
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            batch = {k: v.to(model.device) for k, v in batch.items()}
            
            with autocast():
                # Process source sequences
                source_outputs = model.model(
                    input_ids=batch['source_input_ids'],
                    attention_mask=batch['source_attention_mask'],
                    output_hidden_states=True,
                    return_dict=True
                )
                
                clean_memory()
                
                # Process target sequences
                target_outputs = model.model(
                    input_ids=batch['target_input_ids'],
                    attention_mask=batch['target_attention_mask'],
                    output_hidden_states=True,
                    return_dict=True
                )
                
                try:
                    # Get embeddings and compute loss
                    ref_token_id = model.tokenizer.convert_tokens_to_ids(model.ref_token)
                    
                    source_embeddings = model.get_token_embedding(
                        source_outputs.hidden_states[-1],
                        batch['source_input_ids'],
                        ref_token_id
                    )
                    
                    target_embeddings = model.get_token_embedding(
                        target_outputs.hidden_states[-1],
                        batch['target_input_ids'],
                        ref_token_id
                    )
                    
                    # Normalize embeddings
                    source_embeddings = torch.nn.functional.normalize(source_embeddings, dim=-1)
                    target_embeddings = torch.nn.functional.normalize(target_embeddings, dim=-1)
                    
                    # Compute similarity and loss
                    similarity = torch.matmul(
                        source_embeddings,
                        target_embeddings.transpose(0, 1)
                    ) / temperature
                    
                    labels = torch.arange(similarity.size(0)).to(model.device)
                    loss = torch.nn.CrossEntropyLoss()(similarity, labels)
                    
                    total_loss += loss.item()
                    num_batches += 1
                    
                except ValueError as e:
                    print(f"Skipping validation batch due to error: {str(e)}")
                    continue
                
                clean_memory()
    
    return total_loss / num_batches if num_batches > 0 else float('inf')

def train_model(
    model: AutoregressiveCitationMatcher,
    train_sources: List[str],
    train_targets: List[str],
    val_sources: List[str] = None,
    val_targets: List[str] = None,
    batch_size: int = 8,
    gradient_accumulation_steps: int = 4,
    num_epochs: int = 10,
    learning_rate: float = 1e-4,
    temperature: float = 0.07,
    num_workers: int = 4,
    validation_steps: int = 100,
    save_path: str = 'checkpoint.pt',
    max_length: int = 512
):
    """
    Train the citation matcher model with memory optimizations.
    
    Args:
        model: The AutoregressiveCitationMatcher model
        train_sources: List of source contexts
        train_targets: List of target pages
        val_sources: Optional validation source contexts
        val_targets: Optional validation target pages
        batch_size: Batch size for training
        gradient_accumulation_steps: Number of steps to accumulate gradients
        num_epochs: Number of training epochs
        learning_rate: Learning rate for optimizer
        temperature: Temperature for similarity computation
        num_workers: Number of workers for data loading
        validation_steps: Number of steps between validations
        save_path: Path to save the best model
        max_length: Maximum sequence length
    """
    print(f"Using device: {model.device}")
    print_gpu_memory_status()
    
    # Create datasets with specified max_length
    train_dataset = CitationDataset(
        train_sources,
        train_targets,
        model.tokenizer,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )

    if val_sources and val_targets:
        val_dataset = CitationDataset(
            val_sources,
            val_targets,
            model.tokenizer,
            max_length
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=True
        )

    # Initialize optimizer and scaler for mixed precision training
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scaler = GradScaler()
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=2,
        verbose=True
    )

    best_val_loss = float('inf')
    global_step = 0
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        
        progress_bar = tqdm.tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}')
        
        for batch_idx, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {k: v.to(model.device) for k, v in batch.items()}
            
            # Mixed precision training
            with autocast():
                try:
                    # Process source sequences
                    source_outputs = model.model(
                        input_ids=batch['source_input_ids'],
                        attention_mask=batch['source_attention_mask'],
                        output_hidden_states=True,
                        return_dict=True
                    )
                    
                    clean_memory()
                    
                    # Process target sequences
                    target_outputs = model.model(
                        input_ids=batch['target_input_ids'],
                        attention_mask=batch['target_attention_mask'],
                        output_hidden_states=True,
                        return_dict=True
                    )
                    
                    # Get [REF] token embeddings
                    ref_token_id = model.tokenizer.convert_tokens_to_ids(model.ref_token)
                    
                    source_embeddings = model.get_token_embedding(
                        source_outputs.hidden_states[-1],
                        batch['source_input_ids'],
                        ref_token_id
                    )
                    
                    target_embeddings = model.get_token_embedding(
                        target_outputs.hidden_states[-1],
                        batch['target_input_ids'],
                        ref_token_id
                    )
                    
                    # Normalize embeddings
                    source_embeddings = torch.nn.functional.normalize(source_embeddings, dim=-1)
                    target_embeddings = torch.nn.functional.normalize(target_embeddings, dim=-1)
                    
                    # Compute similarity matrix
                    similarity = torch.matmul(
                        source_embeddings,
                        target_embeddings.transpose(0, 1)
                    ) / temperature
                    
                    # Compute loss
                    labels = torch.arange(similarity.size(0)).to(model.device)
                    loss = torch.nn.CrossEntropyLoss()(similarity, labels)
                    
                    # Scale loss by gradient accumulation steps
                    loss = loss / gradient_accumulation_steps
                    
                    # Backward pass with gradient scaling
                    scaler.scale(loss).backward()
                    
                    if (batch_idx + 1) % gradient_accumulation_steps == 0:
                        # Unscale gradients and clip
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        
                        # Step optimizer and scaler
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                    
                    total_loss += loss.item() * gradient_accumulation_steps
                    num_batches += 1
                    
                    # Update progress bar
                    progress_bar.set_postfix({
                        'loss': f'{loss.item() * gradient_accumulation_steps:.4f}',
                        'avg_loss': f'{total_loss / num_batches:.4f}'
                    })
                    
                    global_step += 1
                    
                    # Validation
                    if val_sources and val_targets and global_step % validation_steps == 0:
                        val_loss = validate_model(
                            model,
                            val_loader,
                            temperature,
                            scaler
                        )
                        
                        # Update learning rate scheduler
                        scheduler.step(val_loss)
                        
                        # Save best model
                        if val_loss < best_val_loss:
                            best_val_loss = val_loss
                            torch.save({
                                'epoch': epoch,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'scheduler_state_dict': scheduler.state_dict(),
                                'scaler_state_dict': scaler.state_dict(),
                                'loss': best_val_loss,
                            }, save_path)
                            print(f"\nSaved new best model with validation loss: {val_loss:.4f}")
                        
                        model.train()  # Switch back to training mode
                    
                except ValueError as e:
                    print(f"Skipping batch due to error: {str(e)}")
                    continue
                
                # Clean memory periodically
                if batch_idx % 10 == 0:
                    clean_memory()
                
        # End of epoch
        epoch_loss = total_loss / num_batches
        print(f"\nEpoch {epoch + 1} average loss: {epoch_loss:.4f}")
        print_gpu_memory_status()

def split_data(
    sources: List[str],
    targets: List[str],
    val_ratio: float = 0.1
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Split data into training and validation sets."""
    val_size = int(len(sources) * val_ratio)
    indices = torch.randperm(len(sources))
    
    train_indices = indices[val_size:]
    val_indices = indices[:val_size]
    
    train_sources = [sources[i] for i in train_indices]
    train_targets = [targets[i] for i in train_indices]
    val_sources = [sources[i] for i in val_indices]
    val_targets = [targets[i] for i in val_indices]
    
    return train_sources, train_targets, val_sources, val_targets

# Usage example:
if __name__ == "__main__":
    # Split the data
    train_sources, train_targets, val_sources, val_targets = split_data(sources, targets)

    # Initialize model
    model = AutoregressiveCitationMatcher()

    # Train with memory optimizations
    train_model(
        model=model,
        train_sources=train_sources,
        train_targets=train_targets,
        val_sources=val_sources,
        val_targets=val_targets,
        batch_size=8,  # Reduced batch size
        gradient_accumulation_steps=4,  # Effective batch size = 8 * 4 = 32
        num_epochs=10,
        learning_rate=1e-4,
        temperature=0.07,
        num_workers=4,
        validation_steps=100,
        save_path='best_model.pt',
        max_length=512  # Adjust based on your GPU memory
    )

  scaler = GradScaler()


Using device: cuda
Using device: cuda
Memory allocated: 3.99 GB
Max memory allocated: 12.00 GB
Memory cached: 12.49 GB


  with autocast():
Epoch 1/10:   0%|                           | 1/11028 [00:01<3:07:07,  1.02s/it]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                           | 2/11028 [00:01<1:44:26,  1.76it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                           | 3/11028 [00:01<1:18:49,  2.33it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                           | 4/11028 [00:01<1:06:40,  2.76it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                             | 5/11028 [00:02<58:22,  3.15it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                             | 6/11028 [00:02<54:14,  3.39it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                             | 7/11028 [00:02<51:24,  3.57it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                             | 8/11028 [00:02<48:02,  3.82it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                             | 9/11028 [00:02<46:52,  3.92it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 10/11028 [00:03<45:38,  4.02it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 11/11028 [00:03<43:51,  4.19it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 12/11028 [00:03<44:01,  4.17it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 14/11028 [00:04<40:54,  4.49it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 15/11028 [00:04<41:00,  4.48it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 16/11028 [00:04<40:02,  4.58it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 17/11028 [00:04<39:15,  4.68it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 19/11028 [00:05<38:43,  4.74it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 20/11028 [00:05<37:58,  4.83it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 22/11028 [00:05<38:18,  4.79it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 23/11028 [00:05<37:37,  4.88it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 25/11028 [00:06<37:42,  4.86it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 26/11028 [00:06<37:02,  4.95it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 28/11028 [00:06<36:42,  4.99it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 30/11028 [00:07<36:25,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 31/11028 [00:07<36:29,  5.02it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 32/11028 [00:07<36:37,  5.00it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 34/11028 [00:08<36:47,  4.98it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 36/11028 [00:08<36:24,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 38/11028 [00:08<36:18,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 40/11028 [00:09<36:00,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 42/11028 [00:09<36:28,  5.02it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 44/11028 [00:10<35:59,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 46/11028 [00:10<35:34,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|                            | 48/11028 [00:10<35:53,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|▏                           | 50/11028 [00:11<36:10,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|▏                           | 52/11028 [00:11<35:50,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   0%|▏                           | 54/11028 [00:12<35:59,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 56/11028 [00:12<36:09,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 57/11028 [00:12<36:21,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 59/11028 [00:13<36:19,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 61/11028 [00:13<36:14,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 63/11028 [00:13<35:56,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 64/11028 [00:14<35:40,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 66/11028 [00:14<36:54,  4.95it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 68/11028 [00:14<36:39,  4.98it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 70/11028 [00:15<36:44,  4.97it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 72/11028 [00:15<36:10,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 74/11028 [00:16<36:14,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 76/11028 [00:16<36:16,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 78/11028 [00:16<35:56,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 80/11028 [00:17<35:38,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 82/11028 [00:17<35:26,  5.15it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 84/11028 [00:18<35:28,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 86/11028 [00:18<35:58,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 87/11028 [00:18<35:55,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 88/11028 [00:18<36:08,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 90/11028 [00:19<36:11,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 92/11028 [00:19<36:05,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 93/11028 [00:19<36:56,  4.93it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 95/11028 [00:20<36:22,  5.01it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 97/11028 [00:20<35:50,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                           | 98/11028 [00:20<35:53,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                           | 99/11028 [00:21<36:14,  5.02it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                          | 101/11028 [00:21<36:29,  4.99it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▏                          | 102/11028 [00:21<36:14,  5.02it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 104/11028 [00:22<36:25,  5.00it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 106/11028 [00:22<35:56,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 108/11028 [00:22<35:26,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 110/11028 [00:23<35:20,  5.15it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 112/11028 [00:23<35:15,  5.16it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 114/11028 [00:24<35:03,  5.19it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 116/11028 [00:24<35:08,  5.18it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 118/11028 [00:24<35:31,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 120/11028 [00:25<35:10,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 122/11028 [00:25<34:52,  5.21it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 124/11028 [00:25<34:56,  5.20it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 126/11028 [00:26<35:07,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 128/11028 [00:26<35:00,  5.19it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 130/11028 [00:27<34:48,  5.22it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 132/11028 [00:27<34:48,  5.22it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 134/11028 [00:27<35:09,  5.16it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 136/11028 [00:28<35:11,  5.16it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 138/11028 [00:28<35:21,  5.13it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 140/11028 [00:29<34:38,  5.24it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 142/11028 [00:29<35:16,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 144/11028 [00:29<35:06,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 146/11028 [00:30<35:05,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 148/11028 [00:30<34:43,  5.22it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 150/11028 [00:30<34:41,  5.23it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▎                          | 152/11028 [00:31<35:04,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▍                          | 154/11028 [00:31<35:02,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▍                          | 156/11028 [00:32<34:44,  5.21it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▍                          | 158/11028 [00:32<34:45,  5.21it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▍                          | 160/11028 [00:32<34:23,  5.27it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▍                          | 162/11028 [00:33<34:57,  5.18it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   1%|▍                          | 164/11028 [00:33<35:04,  5.16it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 166/11028 [00:34<35:17,  5.13it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 168/11028 [00:34<35:01,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 169/11028 [00:34<34:54,  5.18it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 171/11028 [00:35<35:32,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 173/11028 [00:35<35:39,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 175/11028 [00:35<35:18,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 176/11028 [00:36<35:12,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 178/11028 [00:36<35:42,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 180/11028 [00:36<35:21,  5.11it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 182/11028 [00:37<35:44,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 184/11028 [00:37<35:43,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 186/11028 [00:37<35:55,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 188/11028 [00:38<35:07,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 190/11028 [00:38<35:07,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 192/11028 [00:39<35:13,  5.13it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 193/11028 [00:39<35:13,  5.13it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 195/11028 [00:39<35:40,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 196/11028 [00:39<35:42,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 198/11028 [00:40<35:24,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 200/11028 [00:40<35:24,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 202/11028 [00:41<35:22,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▍                          | 204/11028 [00:41<35:07,  5.14it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 206/11028 [00:41<34:54,  5.17it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 208/11028 [00:42<34:38,  5.21it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 210/11028 [00:42<34:55,  5.16it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 212/11028 [00:43<34:29,  5.23it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 214/11028 [00:43<34:40,  5.20it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 216/11028 [00:43<34:47,  5.18it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 218/11028 [00:44<35:10,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 220/11028 [00:44<34:56,  5.16it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 222/11028 [00:45<35:18,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 224/11028 [00:45<35:29,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 225/11028 [00:45<35:26,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 227/11028 [00:45<35:24,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 228/11028 [00:46<35:05,  5.13it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 230/11028 [00:46<35:40,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 232/11028 [00:46<35:19,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 233/11028 [00:47<36:15,  4.96it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 235/11028 [00:47<35:44,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 236/11028 [00:47<35:07,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 237/11028 [00:47<35:36,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 239/11028 [00:48<35:36,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 241/11028 [00:48<35:53,  5.01it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 243/11028 [00:49<35:33,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 244/11028 [00:49<35:31,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 246/11028 [00:49<35:44,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 247/11028 [00:49<35:28,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 250/11028 [00:50<35:16,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 252/11028 [00:50<35:05,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▌                          | 254/11028 [00:51<35:12,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 256/11028 [00:51<34:58,  5.13it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 257/11028 [00:51<35:02,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 258/11028 [00:52<35:22,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 259/11028 [00:52<35:36,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 261/11028 [00:52<35:50,  5.01it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 263/11028 [00:53<35:06,  5.11it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 265/11028 [00:53<35:02,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 267/11028 [00:53<35:10,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 269/11028 [00:54<35:17,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 271/11028 [00:54<35:03,  5.11it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 273/11028 [00:55<35:08,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   2%|▋                          | 275/11028 [00:55<34:49,  5.15it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 277/11028 [00:55<35:25,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 279/11028 [00:56<35:21,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 281/11028 [00:56<35:27,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 283/11028 [00:57<35:24,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 284/11028 [00:57<35:36,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 286/11028 [00:57<35:48,  5.00it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 288/11028 [00:58<35:52,  4.99it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 290/11028 [00:58<35:38,  5.02it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 292/11028 [00:58<35:05,  5.10it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 294/11028 [00:59<35:06,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 296/11028 [00:59<35:17,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 298/11028 [01:00<35:11,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 299/11028 [01:00<35:17,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 301/11028 [01:00<35:16,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 303/11028 [01:01<35:16,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▋                          | 305/11028 [01:01<35:32,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 307/11028 [01:01<35:28,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 309/11028 [01:02<35:24,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 311/11028 [01:02<35:15,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 313/11028 [01:03<35:05,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 314/11028 [01:03<35:22,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 315/11028 [01:03<35:45,  4.99it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 317/11028 [01:03<35:52,  4.97it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 319/11028 [01:04<35:25,  5.04it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 321/11028 [01:04<35:27,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 323/11028 [01:04<35:05,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 325/11028 [01:05<35:01,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 327/11028 [01:05<35:02,  5.09it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 328/11028 [01:05<35:26,  5.03it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 331/11028 [01:06<34:49,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 333/11028 [01:06<35:04,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 335/11028 [01:07<35:14,  5.06it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 337/11028 [01:07<35:04,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 339/11028 [01:08<34:46,  5.12it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 341/11028 [01:08<34:41,  5.13it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 342/11028 [01:08<34:50,  5.11it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 344/11028 [01:09<35:16,  5.05it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 345/11028 [01:09<35:05,  5.07it/s]

Skipping batch due to error: There should only be one occurance of REF token


Epoch 1/10:   3%|▊                          | 347/11028 [01:09<35:03,  5.08it/s]

Skipping batch due to error: There should only be one occurance of REF token
Skipping batch due to error: There should only be one occurance of REF token


Exception in thread Thread-7 (_pin_memory_loop):
Traceback (most recent call last):
  File "/home/amir/miniconda3/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/amir/miniconda3/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/amir/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 55, in _pin_memory_loop
    do_one_step()
  File "/home/amir/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 32, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/amir/miniconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/amir/miniconda3/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 496, in rebuild_storage_fd
    fd = df.detach()
  File "/home/amir/miniconda3/lib/python3.10/multiprocessing/resource_sharer.py", lin

KeyboardInterrupt: 