In [1]:
# SECTION 2: RANDOM RETRIEVAL EXPERIMENT
# LLM prediction with randomly sampled historical candidates
# Tests whether ANY historical context helps vs. SMART retrieval

import gc
import json
import logging
import os
import pickle
import random
import re
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure root logging at INFO level and create the module-level logger
# that the experiment code below uses for all progress and result messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RandomRetrievalExperiment:
    """
    Random retrieval experiment: LLM + randomly sampled historical candidates
    Controls for the effect of having ANY historical context vs. smart retrieval
    """
    
    def __init__(self, test_queries_file: str, ground_truth_file: str, embeddings_dir: str):
        """Load the evaluation data and register the LLMs that can be tested.

        Args:
            test_queries_file: Path to the file of test queries (one JSON
                object per line).
            ground_truth_file: Path to the file of ground-truth entries (one
                JSON object per line, each carrying a 'query_id').
            embeddings_dir: Directory holding the pickled embedding files
                from which the candidate pool is built.
        """
        # Remember the input locations and pick the compute device.
        self.test_queries_file = test_queries_file
        self.ground_truth_file = ground_truth_file
        self.embeddings_dir = embeddings_dir
        device_kind = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_kind)

        # Nothing is loaded until load_llm() is called.
        self.current_llm = None
        self.current_tokenizer = None
        self.current_llm_name = None

        # Registry of supported models, keyed by the short name used in results.
        self.llm_configs = {
            'StockLLM': {
                'model_name': 'ElsaShaw/StockLLM',
                'description': 'Specialized financial LLM',
            },
            'Llama3.2-3B': {
                'model_name': 'meta-llama/Llama-3.2-3B-Instruct',
                'description': 'Medium general-purpose LLM',
            },
            'Qwen2.5-1.5B': {
                'model_name': 'Qwen/Qwen2.5-1.5B-Instruct',
                'description': 'Qwen instruction-following model',
            },
            'Phi3-Mini': {
                'model_name': 'microsoft/Phi-3-mini-4k-instruct',
                'description': 'Microsoft compact LLM',
            },
        }

        # Queries and their labels; the lookup maps query_id -> ground-truth
        # entry (a later duplicate id overwrites an earlier one).
        self.test_queries = self._load_json_file(test_queries_file)
        self.ground_truth = self._load_json_file(ground_truth_file)
        self.gt_lookup = {}
        for entry in self.ground_truth:
            self.gt_lookup[entry['query_id']] = entry

        # Candidate pool used for random sampling.
        self.candidates = self._load_candidates()

        logger.info(f"Loaded {len(self.test_queries)} test queries")
        logger.info(f"Loaded {len(self.ground_truth)} ground truth entries")
        logger.info(f"Loaded {len(self.candidates)} candidates for random sampling")
    
    def _load_json_file(self, file_path: str) -> List[Dict]:
        """Load JSONL file"""
        data = []
        with open(file_path, 'r') as f:
            for line in f:
                try:
                    data.append(json.loads(line.strip()))
                except json.JSONDecodeError:
                    continue
        return data
    
    def _load_candidates(self) -> Dict:
        """Build the candidate pool from the pickled embedding files.

        Files matching ``c_*_FinQuest_embeddings_*.pkl`` inside
        ``self.embeddings_dir`` are read in sorted path order. Each file is
        iterated as a sequence of ``{date: [item, ...]}`` mappings, and the
        ``'data'`` dict of every item is stored under its ``'data_index'``.

        Returns:
            Mapping from ``data_index`` to the candidate's data dict. When the
            same index occurs more than once, the entry read last wins.
        """
        # Path.glob yields entries in arbitrary (filesystem) order. Sorting
        # makes duplicate-index resolution and the dict's iteration order,
        # and therefore seeded random sampling, reproducible across machines.
        embedding_files = sorted(Path(self.embeddings_dir).glob("c_*_FinQuest_embeddings_*.pkl"))

        logger.info(f"Loading candidates from {len(embedding_files)} files...")
        if not embedding_files:
            logger.warning(f"No embedding files found in {self.embeddings_dir}")

        candidates = {}
        for file_path in embedding_files:
            # NOTE(security): unpickling can execute arbitrary code embedded
            # in the file. Only point embeddings_dir at trusted data.
            with open(file_path, 'rb') as f:
                embedding_data = pickle.load(f)

            for date_group in embedding_data:
                for candidates_on_date in date_group.values():
                    for candidate_item in candidates_on_date:
                        candidate_data = candidate_item['data']
                        candidates[candidate_data['data_index']] = candidate_data

        return candidates
    
    def load_llm(self, llm_name: str):
        """Load specified LLM for testing"""
        if llm_name not in self.llm_configs:
            raise ValueError(f"Unknown LLM: {llm_name}")
        
        # Clear previous model
        if self.current_llm is not None:
            del self.current_llm
            del self.current_tokenizer
            torch.cuda.empty_cache()
        
        config = self.llm_configs[llm_name]
        model_name = config['model_name']
        
        logger.info(f"Loading {llm_name}: {model_name}")
        
        try:
            self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            if self.current_tokenizer.pad_token is None:
                self.current_tokenizer.pad_token = self.current_tokenizer.eos_token
            
            self.current_llm = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16 if self.device.type == 'cuda' else torch.float32,
                device_map="auto" if self.device.type == 'cuda' else None,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
            
            self.current_llm_name = llm_name
            logger.info(f"✅ Successfully loaded {llm_name}")
            
        except Exception as e:
            logger.error(f"❌ Failed to load {llm_name}: {e}")
            raise
    
    def get_qualified_candidates(self, query_date: str) -> List[Dict]:
        """Get candidates that occur before the query date (temporal safety)"""
        query_dt = datetime.strptime(query_date, "%Y-%m-%d")
        qualified = []
        
        for candidate in self.candidates.values():
            try:
                candidate_dt = datetime.strptime(candidate['candidate_date'], "%Y-%m-%d")
                if query_dt > candidate_dt:  # Only past candidates
                    qualified.append(candidate)
            except (ValueError, KeyError):
                continue  # Skip malformed dates
        
        return qualified
    
    def sample_random_candidates(self, qualified_candidates: List[Dict], k: int = 5) -> List[Dict]:
        """Randomly sample k candidates from qualified pool"""
        if len(qualified_candidates) <= k:
            return qualified_candidates
        else:
            return random.sample(qualified_candidates, k)
    
    def generate_random_retrieval_prompt(self, query: Dict, random_candidates: List[Dict]) -> Tuple[str, List[int]]:
        """
        Generate prompt with randomly retrieved historical candidates
        Same format as FinQuest retrieval but with random selection
        """
        query_stock = query.get('query_stock', 'Unknown')
        query_date = query.get('query_date', 'Unknown')
        
        instruction = (
            "Based on the following information, predict stock movement by filling in the [blank] with 'rise' or 'fall'. "
            "Just fill in the blank, do not explain.\n"
        )
        
        retrieve_prompt = 'These are randomly selected sequences that may affect this stock\'s price recently:\n'
        
        # Format random candidates
        candidate_text = ""
        candidate_movement_count = [0, 0]  # [rise, fall]
        
        for candidate in random_candidates:
            # Track movement distribution
            movement = candidate.get('movement', 'unknown')
            if movement == 'rise':
                candidate_movement_count[0] += 1
            elif movement == 'fall':
                candidate_movement_count[1] += 1
            
            # Format candidate sequence
            candidate_sequence = {
                'candidate_stock': candidate.get('candidate_stock', 'Unknown'),
                'candidate_date': candidate.get('candidate_date', 'Unknown'),
                'recent_date_list': candidate.get('recent_date_list', []),
                'adjusted_close_list': candidate.get('adjusted_close_list', [])
            }
            
            candidate_text += str({'candidate_sequence': candidate_sequence}) + '\n'
        
        # Query section
        query_prompt = 'This is the query sequence:\n'
        query_sequence = {
            'query_stock': query_stock,
            'query_date': query_date,
            'recent_date_list': query.get('recent_date_list', []),
            'adjusted_close_list': query.get('adjusted_close_list', [])
        }
        
        query_instruction = f'\nQuery: On {query_date}, the movement of ${query_stock} is [blank].\n'
        
        # Combine all parts
        full_prompt = instruction + retrieve_prompt + candidate_text + '\n' + query_prompt + str(query_sequence) + '\n' + query_instruction
        
        return full_prompt, candidate_movement_count
    
    def ask_llm(self, prompt: str) -> str:
        """Get prediction from current LLM"""
        if self.current_llm is None:
            raise RuntimeError("No LLM loaded")
        
        # Format prompt based on LLM architecture
        if 'Llama' in self.current_llm_name:
            messages = [
                {"role": "system", "content": "You are a financial analyst. Use the historical examples to predict stock movements accurately."},
                {"role": "user", "content": prompt}
            ]
            
            formatted_prompt = self.current_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            formatted_prompt = f"System: You are a financial analyst.\nUser: {prompt}\nAssistant:"
        
        # Tokenize and generate
        input_ids = self.current_tokenizer.encode(
            formatted_prompt, 
            return_tensors="pt", 
            truncation=True, 
            max_length=1500  # Longer for retrieval context
        ).to(self.device)
        
        with torch.no_grad():
            outputs = self.current_llm.generate(
                input_ids,
                max_new_tokens=10,
                temperature=0.1,
                do_sample=False,
                pad_token_id=self.current_tokenizer.eos_token_id,
                eos_token_id=self.current_tokenizer.eos_token_id
            )
        
        generated_ids = outputs[0][input_ids.shape[1]:]
        response = self.current_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        
        return response
    
    def extract_prediction(self, response: str, reference: str) -> Tuple[str, bool]:
        """Extract rise/fall prediction from LLM response"""
        response_clean = response.lower().strip()
        reference_clean = reference.lower().strip()
        
        if 'rise' in response_clean:
            prediction = 'rise'
        elif 'fall' in response_clean:
            prediction = 'fall'
        else:
            prediction = 'freeze'
        
        correct = (prediction == reference_clean)
        return prediction, correct
    
    def run_single_llm_experiment(self, llm_name: str, output_dir: str, k_candidates: int = 5) -> pd.DataFrame:
        """Run the random-retrieval baseline for one LLM and save the results.

        For each test query with a ground truth other than 'freeze', up to
        ``k_candidates`` strictly earlier candidates are sampled at random,
        placed into the prompt, and the model's answer is scored. Per-query
        failures are logged and skipped. Results are written to
        ``<output_dir>/<llm_name>_random_retrieval_k<k>_results.csv``.

        Args:
            llm_name: Key of the LLM configuration to load.
            output_dir: Directory for the CSV (created if missing).
            k_candidates: Number of random candidates per query.

        Returns:
            DataFrame with one row per scored query. It always has the result
            columns, even when no query could be scored.

        Raises:
            Exception: Whatever ``load_llm`` raises for this model.
        """
        logger.info(f"🚀 Starting RANDOM RETRIEVAL experiment with {llm_name} (k={k_candidates})")

        # Load the specified LLM
        self.load_llm(llm_name)

        # Fixed schema so an empty run still yields a frame with these columns
        # (a bare pd.DataFrame([]) has none and df['reference'] would raise).
        result_columns = [
            'llm_name', 'query_id', 'query_stock', 'query_date', 'method',
            'prompt', 'llm_response', 'prediction', 'reference', 'correct',
            'candidate_count', 'candidate_movement_dist', 'k_candidates',
            'total_qualified_candidates',
        ]

        results = []
        processed_count = 0
        correct_count = 0
        total_queries = len(self.test_queries)
        # Many queries share a date; the date filter only depends on the date,
        # so compute it once per distinct date instead of once per query.
        qualified_by_date = {}

        for i, query in enumerate(self.test_queries):
            # .get so one malformed record is skipped instead of aborting the run.
            query_id = query.get('query_id')
            ground_truth = self.gt_lookup.get(query_id)

            if not ground_truth:
                continue

            reference_answer = ground_truth.get('actual_movement')

            # Skip unlabeled entries and freeze movements
            if reference_answer is None or reference_answer == 'freeze':
                continue

            # Progress logging
            if (i + 1) % 50 == 0:
                accuracy = correct_count / processed_count if processed_count > 0 else 0
                logger.info(f"Progress: {i+1}/{total_queries} | Accuracy so far: {accuracy:.3f}")

            try:
                # Get qualified candidates (temporal safety)
                query_date = query['query_date']
                if query_date not in qualified_by_date:
                    qualified_by_date[query_date] = self.get_qualified_candidates(query_date)
                qualified_candidates = qualified_by_date[query_date]

                if len(qualified_candidates) == 0:
                    logger.warning(f"No qualified candidates for query {query_id}")
                    continue

                # Randomly sample k candidates
                random_candidates = self.sample_random_candidates(qualified_candidates, k_candidates)

                # Generate prompt with random candidates
                prompt, candidate_movement_count = self.generate_random_retrieval_prompt(query, random_candidates)

                # Get LLM prediction
                llm_response = self.ask_llm(prompt)

                # Extract and validate prediction
                prediction, correct = self.extract_prediction(llm_response, reference_answer)

                # Track accuracy
                processed_count += 1
                if correct:
                    correct_count += 1

                results.append({
                    'llm_name': llm_name,
                    'query_id': query_id,
                    'query_stock': query.get('query_stock', ''),
                    'query_date': query.get('query_date', ''),
                    'method': 'random_retrieval',
                    'prompt': prompt,
                    'llm_response': llm_response,
                    'prediction': prediction,
                    'reference': reference_answer,
                    'correct': correct,
                    'candidate_count': len(random_candidates),
                    'candidate_movement_dist': str(candidate_movement_count),
                    'k_candidates': k_candidates,
                    'total_qualified_candidates': len(qualified_candidates)
                })

            except Exception as e:
                # Best effort: one bad query must not discard the whole run.
                logger.error(f"Error processing query {query_id}: {e}")
                continue

        # Create results DataFrame
        df = pd.DataFrame(results, columns=result_columns)

        # Save detailed results
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f'{llm_name}_random_retrieval_k{k_candidates}_results.csv')
        df.to_csv(output_file, index=False)

        # Calculate metrics (0 when there is nothing to average)
        total_predictions = len(df)
        if total_predictions == 0:
            logger.warning(f"No predictions were produced for {llm_name}")
        overall_accuracy = df['correct'].mean() if total_predictions > 0 else 0

        # Class-specific metrics
        rise_df = df[df['reference'] == 'rise']
        fall_df = df[df['reference'] == 'fall']

        rise_accuracy = rise_df['correct'].mean() if len(rise_df) > 0 else 0
        fall_accuracy = fall_df['correct'].mean() if len(fall_df) > 0 else 0

        # Candidate statistics
        avg_candidates_used = df['candidate_count'].mean() if total_predictions > 0 else 0

        # Log results
        logger.info("="*60)
        logger.info(f"RANDOM RETRIEVAL RESULTS - {llm_name}")
        logger.info("="*60)
        logger.info(f"Total Predictions: {total_predictions}")
        logger.info(f"Overall Accuracy: {overall_accuracy:.4f}")
        logger.info(f"Rise Accuracy: {rise_accuracy:.4f} ({len(rise_df)} samples)")
        logger.info(f"Fall Accuracy: {fall_accuracy:.4f} ({len(fall_df)} samples)")
        logger.info(f"Avg Candidates Used: {avg_candidates_used:.1f}")
        logger.info(f"Results saved to: {output_file}")
        logger.info("="*60)

        return df
    
    def run_multi_llm_experiment(self, llm_list: List[str], output_dir: str, k_candidates: int = 5) -> pd.DataFrame:
        """Run the random-retrieval experiment for each LLM in ``llm_list``.

        An LLM whose run raises is logged and skipped. The frames of the
        successful runs are concatenated, written to a combined CSV, and
        summarised in a comparison report.

        Returns:
            The combined results, or an empty DataFrame if every run failed.
        """
        banner = '=' * 60
        logger.info(f"🚀 Starting MULTI-LLM RANDOM RETRIEVAL experiment")
        logger.info(f"Testing LLMs: {llm_list}")
        logger.info(f"Candidates per query: {k_candidates}")

        per_llm_frames = []
        for llm_name in llm_list:
            logger.info(f"\n{banner}")
            logger.info(f"TESTING LLM: {llm_name}")
            logger.info(f"{banner}")

            try:
                frame = self.run_single_llm_experiment(llm_name, output_dir, k_candidates)
                per_llm_frames.append(frame)
                logger.info(f"✅ Completed {llm_name}")
            except Exception as e:
                # Keep going: one broken model should not cancel the others.
                logger.error(f"❌ Failed {llm_name}: {e}")

        # Nothing succeeded: report and hand back an empty frame.
        if not per_llm_frames:
            logger.error("❌ No successful experiments!")
            return pd.DataFrame()

        combined_df = pd.concat(per_llm_frames, ignore_index=True)

        combined_name = f'all_llms_random_retrieval_k{k_candidates}_combined.csv'
        combined_file = os.path.join(output_dir, combined_name)
        combined_df.to_csv(combined_file, index=False)

        self._generate_comparison_report(combined_df, output_dir, k_candidates)

        logger.info(f"\n✅ Multi-LLM random retrieval experiment completed!")
        logger.info(f"Combined results saved to: {combined_file}")

        return combined_df
    
    def run_k_ablation_study(self, llm_name: str, output_dir: str, k_values: Optional[List[int]] = None) -> pd.DataFrame:
        """Run the experiment for one LLM at several candidate counts.

        Args:
            llm_name: Key of the LLM configuration to test.
            output_dir: Directory for the per-k CSVs and the analysis file.
            k_values: Candidate counts to test; defaults to ``[1, 3, 5, 10]``.

        Returns:
            The concatenated results of all k values that completed, or an
            empty DataFrame if none did.
        """
        # None sentinel instead of a mutable list default shared across calls.
        if k_values is None:
            k_values = [1, 3, 5, 10]

        logger.info(f"🔬 Starting K-ABLATION STUDY for {llm_name}")
        logger.info(f"Testing k values: {k_values}")

        all_results = []
        # k values that actually produced results, index-aligned with
        # all_results. Passing the full k_values to the analysis would
        # mislabel every result after a failed k.
        completed_k_values = []

        for k in k_values:
            logger.info(f"\n{'='*40}")
            logger.info(f"TESTING k={k} candidates")
            logger.info(f"{'='*40}")

            try:
                # Run experiment with k candidates
                k_results = self.run_single_llm_experiment(llm_name, output_dir, k)
            except Exception as e:
                logger.error(f"❌ Failed k={k}: {e}")
                continue

            all_results.append(k_results)
            completed_k_values.append(k)
            logger.info(f"✅ Completed k={k}")

        if not all_results:
            logger.error("❌ K-ablation study failed!")
            return pd.DataFrame()

        # Analyze k-ablation results
        self._analyze_k_ablation(all_results, llm_name, output_dir, completed_k_values)

        return pd.concat(all_results, ignore_index=True)
    
    def _analyze_k_ablation(self, results_list: List[pd.DataFrame], llm_name: str, output_dir: str, k_values: List[int]):
        """Summarise accuracy per candidate count and save it as a CSV.

        Args:
            results_list: One results DataFrame per tested k.
            llm_name: Name of the LLM, used in the file name and log output.
            output_dir: Directory for ``<llm_name>_k_ablation_analysis.csv``
                (created if missing).
            k_values: k for each frame, by position. It is only used for a
                frame that does not record its own 'k_candidates'.
        """

        def mean_correct(frame: pd.DataFrame) -> float:
            # 0.0 for an empty selection (instead of NaN), matching how the
            # other reports in this class treat missing classes.
            return float(frame['correct'].mean()) if len(frame) > 0 else 0.0

        k_analysis = []

        for i, df in enumerate(results_list):
            has_rows = len(df) > 0 and 'correct' in df.columns

            # Prefer the k stored with the results: positional pairing with
            # k_values is wrong whenever a failed k left a gap in results_list.
            if has_rows and 'k_candidates' in df.columns:
                k = int(df['k_candidates'].iloc[0])
            else:
                k = k_values[i]

            if has_rows:
                accuracy = mean_correct(df)
                rise_acc = mean_correct(df[df['reference'] == 'rise'])
                fall_acc = mean_correct(df[df['reference'] == 'fall'])
            else:
                accuracy = rise_acc = fall_acc = 0.0

            k_analysis.append({
                'k_candidates': k,
                'overall_accuracy': accuracy,
                'rise_accuracy': rise_acc,
                'fall_accuracy': fall_acc,
                'total_predictions': len(df)
            })

        # Save k-ablation analysis
        k_df = pd.DataFrame(k_analysis)
        os.makedirs(output_dir, exist_ok=True)
        k_file = os.path.join(output_dir, f'{llm_name}_k_ablation_analysis.csv')
        k_df.to_csv(k_file, index=False)

        # Print analysis
        logger.info(f"\nK-ABLATION ANALYSIS - {llm_name}")
        logger.info("="*50)
        for result in k_analysis:
            logger.info(f"k={result['k_candidates']:2d} | Accuracy: {result['overall_accuracy']:.4f} | "
                       f"Rise: {result['rise_accuracy']:.3f} | Fall: {result['fall_accuracy']:.3f}")

        logger.info(f"\nK-ablation analysis saved to: {k_file}")
    
    def _generate_comparison_report(self, combined_df: pd.DataFrame, output_dir: str, k_candidates: int):
        """Write and log a per-LLM summary of the combined results.

        For every distinct 'llm_name' the report holds overall and per-class
        accuracy, the number of rise/fall references, how often each label
        was predicted, and the mean number of candidates used. The table is
        saved to ``random_retrieval_k<k>_llm_comparison.csv`` in ``output_dir``.
        """

        def accuracy_or_zero(subset):
            # A class without samples counts as 0 accuracy.
            return subset['correct'].mean() if len(subset) > 0 else 0

        comparison_results = []

        for llm_name in combined_df['llm_name'].unique():
            llm_rows = combined_df[combined_df['llm_name'] == llm_name]

            rise_rows = llm_rows[llm_rows['reference'] == 'rise']
            fall_rows = llm_rows[llm_rows['reference'] == 'fall']

            # How often each label was predicted by this LLM.
            predicted = llm_rows['prediction'].value_counts()

            summary = {
                'llm_name': llm_name,
                'total_predictions': len(llm_rows),
                'overall_accuracy': llm_rows['correct'].mean(),
                'rise_accuracy': accuracy_or_zero(rise_rows),
                'fall_accuracy': accuracy_or_zero(fall_rows),
                'rise_predictions': len(rise_rows),
                'fall_predictions': len(fall_rows),
                'predicted_rise': predicted.get('rise', 0),
                'predicted_fall': predicted.get('fall', 0),
                'predicted_freeze': predicted.get('freeze', 0),
                'avg_candidates_used': llm_rows['candidate_count'].mean(),
                'k_candidates': k_candidates,
                'method': 'random_retrieval'
            }
            comparison_results.append(summary)

        # Persist the table.
        comparison_file = os.path.join(output_dir, f'random_retrieval_k{k_candidates}_llm_comparison.csv')
        pd.DataFrame(comparison_results).to_csv(comparison_file, index=False)

        # Log a one-line summary per LLM.
        rule = "=" * 80
        logger.info("\n" + rule)
        logger.info(f"RANDOM RETRIEVAL (k={k_candidates}) - LLM COMPARISON SUMMARY")
        logger.info(rule)

        for row in comparison_results:
            logger.info(f"{row['llm_name']:15} | Accuracy: {row['overall_accuracy']:.4f} | "
                       f"Rise: {row['rise_accuracy']:.3f} | Fall: {row['fall_accuracy']:.3f} | "
                       f"Candidates: {row['avg_candidates_used']:.1f}")

        logger.info(f"\nDetailed comparison saved to: {comparison_file}")

# USAGE EXAMPLE
def run_random_retrieval_experiment():
    """Entry point: run the multi-LLM random-retrieval experiment.

    Uses the hard-coded data locations below, tests three LLMs with five
    random candidates per query, and logs (rather than propagates) any
    failure together with its traceback.
    """

    # Input and output locations
    queries_dir = '/root/nfs/AJ FinRag/Evaluation Results/Test Queries'
    test_queries_file = queries_dir + '/test_queries_rise_fall_only.json'
    ground_truth_file = queries_dir + '/ground_truth_rise_fall_only.json'
    embeddings_dir = '/root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest'
    output_dir = 'random_retrieval_experiments'

    # Models under test: a specialized financial model, a small general
    # model, and a medium general model.
    test_llms = ['StockLLM', 'Qwen2.5-1.5B', 'Llama3.2-3B']

    # Number of random candidates placed in each prompt
    k_candidates = 5

    try:
        experiment = RandomRetrievalExperiment(test_queries_file, ground_truth_file, embeddings_dir)

        # Option 1: multi-LLM experiment with a fixed k
        logger.info("🚀 RUNNING MULTI-LLM RANDOM RETRIEVAL EXPERIMENT")
        results = experiment.run_multi_llm_experiment(test_llms, output_dir, k_candidates)

        if len(results):
            logger.info("🎉 RANDOM RETRIEVAL EXPERIMENT COMPLETED SUCCESSFULLY!")
            logger.info(f"📊 Total results: {len(results)} predictions across {len(test_llms)} LLMs")

        # Option 2: k-ablation study via
        # experiment.run_k_ablation_study('StockLLM', output_dir, [1, 3, 5, 10])
        # (not run by default).

    except Exception as e:
        logger.error(f"❌ Experiment failed: {e}")
        import traceback
        traceback.print_exc()

# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_random_retrieval_experiment()

INFO:__main__:Loading candidates from 75 files...
INFO:__main__:Loaded 2642 test queries
INFO:__main__:Loaded 2642 ground truth entries
INFO:__main__:Loaded 8151 candidates for random sampling
INFO:__main__:🚀 RUNNING MULTI-LLM RANDOM RETRIEVAL EXPERIMENT
INFO:__main__:🚀 Starting MULTI-LLM RANDOM RETRIEVAL experiment
INFO:__main__:Testing LLMs: ['StockLLM', 'Qwen2.5-1.5B', 'Llama3.2-3B']
INFO:__main__:Candidates per query: 5
INFO:__main__:
INFO:__main__:TESTING LLM: StockLLM
INFO:__main__:🚀 Starting RANDOM RETRIEVAL experiment with StockLLM (k=5)
INFO:__main__:Loading StockLLM: ElsaShaw/StockLLM
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

INFO:__main__:✅ Successfully loaded StockLLM
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
INFO:__main__:Progress: 50/2642 | Accuracy so far: 0.510
INFO:__main__:Progress: 100/2642 | Accuracy so far: 0.455
INFO:__main__:Progress: 150/2642 | Accuracy so far: 0.470
INFO:__main__:Progress: 200/2642 | Accuracy so far: 0.472
INFO:__main__:Progress: 250/2642 | Accuracy so far: 0.490
INFO:__main__:Progress: 300/2642 | Accuracy so far: 0.482
INFO:__main__:Progress: 350/2642 | Accuracy so far: 0.493
INFO:__main__:Progress: 400/2642 | Accuracy so far: 0.499
INFO:__main__:Progress: 450/2642 | Accuracy so far: 0.503
INFO:__main__:Progress: 500/2642 | Accuracy so far: 0.513
INFO:__main__:Progress: 550/2642 | Accuracy so far: 0.508
INFO:__main__:Progress: 600/2642 | Accuracy so far: 0.503
INFO:__main__:Progress: 650/264

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

INFO:__main__:✅ Successfully loaded Llama3.2-3B
INFO:__main__:Progress: 50/2642 | Accuracy so far: 0.510
INFO:__main__:Progress: 100/2642 | Accuracy so far: 0.455
INFO:__main__:Progress: 150/2642 | Accuracy so far: 0.470
INFO:__main__:Progress: 200/2642 | Accuracy so far: 0.472
INFO:__main__:Progress: 250/2642 | Accuracy so far: 0.486
INFO:__main__:Progress: 300/2642 | Accuracy so far: 0.468
INFO:__main__:Progress: 350/2642 | Accuracy so far: 0.481
INFO:__main__:Progress: 400/2642 | Accuracy so far: 0.489
INFO:__main__:Progress: 450/2642 | Accuracy so far: 0.494
INFO:__main__:Progress: 500/2642 | Accuracy so far: 0.505
INFO:__main__:Progress: 550/2642 | Accuracy so far: 0.501
INFO:__main__:Progress: 600/2642 | Accuracy so far: 0.496
INFO:__main__:Progress: 650/2642 | Accuracy so far: 0.495
INFO:__main__:Progress: 700/2642 | Accuracy so far: 0.496
INFO:__main__:Progress: 750/2642 | Accuracy so far: 0.497
INFO:__main__:Progress: 800/2642 | Accuracy so far: 0.498
INFO:__main__:Progress: 8