<a href="https://colab.research.google.com/github/aakashak2000/LoRA_Mistral/blob/main/Mistral_Fine_Tune_Loan_Risk.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install --upgrade peft bitsandbytes accelerate sec-edgar-downloader evaluate wandb -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m99.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m116.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m99.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m56.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [19]:
# First, uninstall any conflicting versions
!pip uninstall -y peft accelerate -q

# Install compatible versions that work together
!pip install transformers
!pip install peft==0.7.1 -q
!pip install bitsandbytes==0.42.0 -q
!pip install accelerate==0.26.0 -q
!pip install datasets==2.14.6 -q
!pip install scipy -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [7]:
import os
import re
import json
import torch
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from bs4 import BeautifulSoup
from collections import Counter
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import seaborn as sns

from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer, DataCollatorForLanguageModeling

import warnings
warnings.filterwarnings('ignore')

In [8]:
def download_sec_data():

    try:
        url = "https://raw.githubusercontent.com/datasets/s-and-p-500-companies/master/data/constituents.csv"
        sp500_df = pd.read_csv(url)
        sample_companies = sp500_df
        sample_companies.to_csv("data/sp500.csv", index=False)
        return sample_companies

    except Exception as e:
        print(f"Error downloading SEC data: {e}")
        return None

def download_loan_data():
    os.makedirs('data', exist_ok=True)
    try:
        url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00350/default%20of%20credit%20card%20clients.xls'
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        raw_file = 'data/uci_credit_raw.xls'
        with open(raw_file, 'wb') as f:
            f.write(response.content)
        df = pd.read_excel(raw_file, header=1)
        loan_data = pd.DataFrame()
        loan_data['loan_amnt'] = df['LIMIT_BAL'].clip(1000, 40000)  # Cap at realistic loan amounts
        loan_data['term'] = np.where(
            df['LIMIT_BAL'] > 100000,
            '60 months',
            '36 months'
        )
        payment_delay = df['PAY_0'].fillna(0)
        base_rate = 8.0  # Base interest rate
        loan_data['int_rate'] = np.clip(
            base_rate + (payment_delay * 2.5) + np.random.normal(0, 1, len(df)),
            5.0, 30.0
        ).round(2)
        def assign_grade(row):
            score = 0

            if row['PAY_0'] <= 0:  # Paid on time or early
                score += 3
            elif row['PAY_0'] <= 2:  # 1-2 months delay
                score += 1
            # else: 0 points for 3+ months delay

            education = row.get('EDUCATION', 0)
            if education in [1, 2]:  # Graduate school, University
                score += 2
            elif education == 3:  # High school
                score += 1

            age = row.get('AGE', 30)
            if age >= 40:
                score += 1
            elif age >= 30:
                score += 0.5

            if row['LIMIT_BAL'] > 0:
                utilization = row.get('BILL_AMT1', 0) / row['LIMIT_BAL']
                if utilization < 0.3:
                    score += 1
                elif utilization > 0.8:
                    score -= 1

            if score >= 5:
                return 'A'
            elif score >= 4:
                return 'B'
            elif score >= 3:
                return 'C'
            elif score >= 2:
                return 'D'
            elif score >= 1:
                return 'E'
            elif score >= 0:
                return 'F'
            else:
                return 'G'

        loan_data['grade'] = df.apply(assign_grade, axis=1)

        age = df['AGE'].fillna(35)
        emp_length_map = []
        for a in age:
            if a < 25:
                emp_length_map.append(np.random.choice(['< 1 year', '1 year', '2 years'], p=[0.4, 0.3, 0.3]))
            elif a < 35:
                emp_length_map.append(np.random.choice(['2 years', '3 years', '4 years', '5 years'], p=[0.25, 0.25, 0.25, 0.25]))
            elif a < 45:
                emp_length_map.append(np.random.choice(['5 years', '6 years', '7 years', '8 years', '10+ years'], p=[0.2, 0.2, 0.2, 0.2, 0.2]))
            else:
                emp_length_map.append(np.random.choice(['8 years', '9 years', '10+ years'], p=[0.3, 0.3, 0.4]))

        loan_data['emp_length'] = emp_length_map

        base_income = df['LIMIT_BAL'] * 2.5  # Rough credit-to-income ratio

        education_multiplier = df['EDUCATION'].map({
            1: 1.4,  # Graduate school
            2: 1.2,  # University
            3: 1.0,  # High school
            4: 0.8,  # Others
            5: 0.8,  # Unknown
            6: 0.8   # Unknown
        }).fillna(1.0)

        loan_data['annual_inc'] = np.clip(
            base_income * education_multiplier + np.random.normal(0, 10000, len(df)),
            25000, 300000
        ).round(0)

        default_indicator = df['default payment next month'].fillna(0)

        loan_status_map = []
        for default in default_indicator:
            if default == 1:
                loan_status_map.append(np.random.choice([
                    'Charged Off', 'Late (31-120 days)', 'Late (16-30 days)'
                ], p=[0.7, 0.2, 0.1]))
            else:
                loan_status_map.append(np.random.choice([
                    'Fully Paid', 'Current'
                ], p=[0.8, 0.2]))

        loan_data['loan_status'] = loan_status_map

        output_file = 'data/sample_loans.csv'
        loan_data.to_csv(output_file, index=False)
        grade_dist = loan_data['grade'].value_counts().sort_index()
        for grade, count in grade_dist.items():
            pct = count / len(loan_data) * 100
            default_mask = loan_data['grade'] == grade
            default_rate = loan_data[default_mask]['loan_status'].isin(['Charged Off', 'Late (31-120 days)', 'Late (16-30 days)']).mean() * 100
        status_dist = loan_data['loan_status'].value_counts()
        for status, count in status_dist.items():
            pct = count / len(loan_data) * 100
        os.remove(raw_file)
        return loan_data

    except requests.exceptions.RequestException as e:
        return None

    except Exception as e:
        return None

def download_financial_news():

    companies_by_industry = {
        'Technology': {
            'companies': [
                ('Apple', 'Consumer Electronics'), ('Microsoft', 'Enterprise Software'),
                ('Google', 'Internet Services'), ('Meta', 'Social Media Platform'),
                ('NVIDIA', 'Semiconductors'), ('Intel', 'Semiconductors'),
                ('Salesforce', 'Cloud Software'), ('Oracle', 'Database Software'),
                ('IBM', 'Enterprise Technology'), ('Cisco', 'Networking Equipment')
            ],
            'headlines': {
                'Critical': [
                    "{company} faces antitrust investigation over market dominance",
                    "{company} data breach exposes millions of user accounts",
                    "{company} CEO arrested on securities fraud charges",
                    "{company} major product recall due to security vulnerabilities",
                    "{company} loses landmark patent lawsuit, faces $2B damages",
                    "{company} trading halted amid insider trading scandal",
                    "{company} faces federal investigation into data privacy violations"
                ],
                'High': [
                    "{company} hit with $500M fine for antitrust violations",
                    "{company} major cloud outage disrupts global operations",
                    "{company} loses key enterprise customers to competitors",
                    "{company} patent cliff threatens core revenue stream",
                    "{company} regulatory scrutiny over AI bias allegations",
                    "{company} cyber attack compromises intellectual property",
                    "{company} faces employee class action over workplace conditions"
                ],
                'Medium': [
                    "{company} quarterly earnings miss due to supply chain issues",
                    "{company} faces increased competition in cloud services",
                    "{company} chip shortage delays new product launches",
                    "{company} regulatory review of acquisition proposal",
                    "{company} employee departures raise talent retention concerns",
                    "{company} rising R&D costs pressure profit margins"
                ],
                'Low': [
                    "{company} reports strong cloud revenue growth",
                    "{company} launches innovative AI-powered platform",
                    "{company} wins major enterprise contract worth $1B",
                    "{company} patent portfolio strengthens competitive position",
                    "{company} successful product launch drives user adoption",
                    "{company} strategic partnership enhances capabilities"
                ]
            }
        },

        'Financial Services': {
            'companies': [
                ('JPMorgan Chase', 'Investment Banking'), ('Bank of America', 'Commercial Banking'),
                ('Goldman Sachs', 'Investment Banking'), ('Wells Fargo', 'Commercial Banking'),
                ('Morgan Stanley', 'Wealth Management'), ('Citigroup', 'Global Banking'),
                ('Visa', 'Payment Processing'), ('Mastercard', 'Payment Processing'),
                ('PayPal', 'Digital Payments'), ('American Express', 'Credit Cards')
            ],
            'headlines': {
                'Critical': [
                    "{company} faces money laundering investigation by federal authorities",
                    "{company} massive trading loss threatens capital adequacy",
                    "{company} executive indicted on embezzlement charges",
                    "{company} fails stress test, forced to suspend dividends",
                    "{company} regulatory license revoked in key jurisdiction"
                ],
                'High': [
                    "{company} credit loss provisions surge on loan defaults",
                    "{company} hit with $200M fine for compliance violations",
                    "{company} mortgage lending practices under regulatory review",
                    "{company} investment banking revenues plunge 40%",
                    "{company} credit rating downgraded on asset quality concerns"
                ],
                'Medium': [
                    "{company} net interest margin pressure from rate environment",
                    "{company} increased competition in digital banking space",
                    "{company} credit card charge-offs rise above expectations",
                    "{company} regulatory capital requirements increase",
                    "{company} technology modernization requires significant investment"
                ],
                'Low': [
                    "{company} strong loan growth drives revenue increase",
                    "{company} successful stress test results announced",
                    "{company} digital banking platform gains market share",
                    "{company} investment banking fees surge on M&A activity",
                    "{company} announces increased dividend and share buyback"
                ]
            }
        },

        'Healthcare': {
            'companies': [
                ('Johnson & Johnson', 'Pharmaceuticals'), ('Pfizer', 'Pharmaceuticals'),
                ('Moderna', 'Biotechnology'), ('AbbVie', 'Biotechnology'),
                ('Bristol Myers', 'Pharmaceuticals'), ('Merck', 'Pharmaceuticals'),
                ('Abbott', 'Medical Devices'), ('Gilead Sciences', 'Biotechnology'),
                ('Amgen', 'Biotechnology'), ('Biogen', 'Biotechnology')
            ],
            'headlines': {
                'Critical': [
                    "{company} FDA issues black box warning for key drug",
                    "{company} clinical trial halted due to safety concerns",
                    "{company} faces criminal charges over opioid marketing",
                    "{company} major drug recall due to contamination issues",
                    "{company} loses blockbuster drug patent, faces generic competition"
                ],
                'High': [
                    "{company} late-stage clinical trial fails to meet endpoints",
                    "{company} FDA rejects new drug application",
                    "{company} faces $1B lawsuit over drug side effects",
                    "{company} manufacturing facility shut down by regulators",
                    "{company} drug pricing investigation launched by Congress"
                ],
                'Medium': [
                    "{company} clinical trial results show mixed efficacy",
                    "{company} faces increased competition from biosimilars",
                    "{company} R&D costs rise as pipeline advances",
                    "{company} regulatory review delays drug approval timeline",
                    "{company} healthcare reform legislation threatens pricing"
                ],
                'Low': [
                    "{company} receives FDA breakthrough therapy designation",
                    "{company} positive Phase 3 trial results announced",
                    "{company} new drug launch exceeds sales expectations",
                    "{company} expands into oncology through strategic acquisition",
                    "{company} strong pipeline drives future growth prospects"
                ]
            }
        },

        'Energy': {
            'companies': [
                ('Exxon Mobil', 'Oil & Gas'), ('Chevron', 'Oil & Gas'),
                ('ConocoPhillips', 'Oil & Gas'), ('Marathon Oil', 'Oil & Gas'),
                ('Kinder Morgan', 'Pipeline'), ('Enbridge', 'Pipeline')
            ],
            'headlines': {
                'Critical': [
                    "{company} major oil spill creates environmental liability",
                    "{company} pipeline explosion causes environmental disaster",
                    "{company} faces criminal charges over emissions violations",
                    "{company} forced to shut down major refinery indefinitely"
                ],
                'High': [
                    "{company} commodity price crash impacts quarterly earnings",
                    "{company} environmental regulations force asset write-downs",
                    "{company} faces $2B fine for air quality violations",
                    "{company} hurricane damage shuts down Gulf operations"
                ],
                'Medium': [
                    "{company} oil price volatility pressures profit margins",
                    "{company} carbon emissions regulations increase costs",
                    "{company} faces pressure from ESG-focused investors",
                    "{company} supply chain disruption affects operations"
                ],
                'Low': [
                    "{company} major oil discovery boosts reserve estimates",
                    "{company} efficiency improvements reduce production costs",
                    "{company} renewable energy investment diversifies portfolio",
                    "{company} strong cash flow supports dividend increase"
                ]
            }
        },

        'Automotive': {
            'companies': [
                ('Tesla', 'Electric Vehicles'), ('Ford', 'Traditional Auto'),
                ('General Motors', 'Traditional Auto'), ('Stellantis', 'Traditional Auto')
            ],
            'headlines': {
                'Critical': [
                    "{company} massive vehicle recall due to fire risk",
                    "{company} autonomous driving system involved in fatal crash",
                    "{company} CEO faces fraud charges over production claims",
                    "{company} factory explosion halts production indefinitely"
                ],
                'High': [
                    "{company} NHTSA investigation into safety defects",
                    "{company} supply chain shortage delays vehicle deliveries",
                    "{company} loses market share to EV competitors",
                    "{company} faces union strike at major manufacturing plants"
                ],
                'Medium': [
                    "{company} quarterly deliveries miss analyst expectations",
                    "{company} semiconductor shortage impacts production",
                    "{company} raw material costs pressure profit margins",
                    "{company} increased competition in EV market segment"
                ],
                'Low': [
                    "{company} EV sales exceed growth projections",
                    "{company} breakthrough in battery technology announced",
                    "{company} receives major fleet order from enterprise customer",
                    "{company} autonomous driving milestone achieved"
                ]
            }
        },

        'Retail': {
            'companies': [
                ('Amazon', 'E-commerce'), ('Walmart', 'Discount Retail'),
                ('Target', 'General Retail'), ('Home Depot', 'Home Improvement'),
                ('Costco', 'Warehouse Club'), ('Starbucks', 'Food Service')
            ],
            'headlines': {
                'Critical': [
                    "{company} massive data breach exposes customer payment info",
                    "{company} faces federal investigation into labor practices",
                    "{company} product safety recall affects millions of items",
                    "{company} warehouse fire disrupts supply chain operations"
                ],
                'High': [
                    "{company} same-store sales decline for third consecutive quarter",
                    "{company} faces class action lawsuit over workplace conditions",
                    "{company} supply chain disruption causes widespread shortages",
                    "{company} loses key supplier relationship"
                ],
                'Medium': [
                    "{company} inventory levels rise as consumer demand softens",
                    "{company} increased competition pressures market share",
                    "{company} supply chain costs impact profit margins",
                    "{company} minimum wage increases affect labor costs"
                ],
                'Low': [
                    "{company} strong holiday sales drive revenue growth",
                    "{company} e-commerce platform expansion accelerates",
                    "{company} successful new store format improves margins",
                    "{company} digital transformation enhances customer experience"
                ]
            }
        }
    }

    news_data = []
    np.random.seed(42)

    headlines_per_industry = {
        'Critical': 8,
        'High': 25,
        'Medium': 40,
        'Low': 50
    }

    for industry, industry_data in companies_by_industry.items():
        companies = industry_data['companies']
        templates = industry_data['headlines']

        for risk_level, count in headlines_per_industry.items():
            risk_templates = templates[risk_level]

            for i in range(count):

                template = np.random.choice(risk_templates)
                company_name, business_type = companies[np.random.randint(0, len(companies))]

                headline = template.format(company=company_name)

                days_ago = np.random.randint(0, 730)
                date = datetime.now() - timedelta(days=days_ago)

                news_data.append({
                    'headline': headline,
                    'risk_level': risk_level,
                    'company': company_name,
                    'sector': industry,
                    'business_type': business_type,
                    'date': date.strftime('%Y-%m-%d'),
                    'source': np.random.choice(['Reuters', 'Bloomberg', 'MarketWatch', 'CNBC', 'WSJ'])
                })

    np.random.shuffle(news_data)
    df = pd.DataFrame(news_data)

    output_file = 'data/financial_news.csv'
    df.to_csv(output_file, index=False)

    all_headlines = ' '.join(df['headline'].str.lower())

    critical_keywords = ['investigation', 'scandal', 'fraud', 'recall', 'lawsuit', 'criminal', 'breach']
    high_keywords = ['fine', 'loss', 'decline', 'plunge', 'fails', 'suspension', 'downgrade']
    medium_keywords = ['miss', 'competition', 'restructuring', 'uncertainty', 'pressure']
    low_keywords = ['beats', 'partnership', 'approval', 'growth', 'innovation', 'award']

    for level, keywords in [('Critical', critical_keywords), ('High', high_keywords),
                           ('Medium', medium_keywords), ('Low', low_keywords)]:
        keyword_counts = {kw: all_headlines.count(kw) for kw in keywords}
        top_keywords = sorted(keyword_counts.items(), key=lambda x: x[1], reverse=True)[:3]

    return df

os.makedirs("data", exist_ok=True)
os.makedirs("models", exist_ok=True)
os.makedirs("results", exist_ok=True)

In [6]:
sec_data = download_sec_data()
loan_data = download_loan_data()
news_data = download_financial_news()
display(sec_data.head()), print(sec_data.shape)
display(loan_data.head()), print(loan_data.shape)
display(news_data.head()), print(news_data.shape)

Unnamed: 0,Symbol,Security,GICS Sector,GICS Sub-Industry,Headquarters Location,Date added,CIK,Founded
0,MMM,3M,Industrials,Industrial Conglomerates,"Saint Paul, Minnesota",1957-03-04,66740,1902
1,AOS,A. O. Smith,Industrials,Building Products,"Milwaukee, Wisconsin",2017-07-26,91142,1916
2,ABT,Abbott Laboratories,Health Care,Health Care Equipment,"North Chicago, Illinois",1957-03-04,1800,1888
3,ABBV,AbbVie,Health Care,Biotechnology,"North Chicago, Illinois",2012-12-31,1551152,2013 (1888)
4,ACN,Accenture,Information Technology,IT Consulting & Other Services,"Dublin, Ireland",2011-07-06,1467373,1989


(503, 8)


Unnamed: 0,loan_amnt,term,int_rate,grade,emp_length,annual_inc,loan_status
0,20000,36 months,13.65,B,< 1 year,62189.0,Charged Off
1,40000,60 months,7.78,A,2 years,300000.0,Charged Off
2,40000,36 months,8.43,A,5 years,260546.0,Current
3,40000,36 months,8.92,B,8 years,161601.0,Fully Paid
4,40000,36 months,6.21,A,9 years,169527.0,Fully Paid


(30000, 7)


Unnamed: 0,headline,risk_level,company,sector,business_type,date,source
0,Ford EV sales exceed growth projections,Low,Ford,Automotive,Traditional Auto,2023-09-10,Bloomberg
1,Oracle wins major enterprise contract worth $1B,Low,Oracle,Technology,Database Software,2023-07-31,Bloomberg
2,Amazon product safety recall affects millions ...,Critical,Amazon,Retail,E-commerce,2024-12-25,MarketWatch
3,Pfizer strong pipeline drives future growth pr...,Low,Pfizer,Healthcare,Pharmaceuticals,2025-02-12,MarketWatch
4,Ford breakthrough in battery technology announced,Low,Ford,Automotive,Traditional Auto,2024-03-18,WSJ


(738, 7)


(None, None)

In [9]:
def clean_financial_text(text):
    if pd.isna(text):
        return ""
    text = str(text).strip()
    financial_abbrevs = {
        r'\bSEC\b': 'Securities and Exchange Commission',
        r'\bFDA\b': 'Food and Drug Administration',
        r'\bIPO\b': 'Initial Public Offering',
        r'\bM&A\b': 'Mergers and Acquisitions',
        r'\bEPS\b': 'Earnings Per Share',
        r'\bROI\b': 'Return on Investment',
        r'\bEBITDA\b': 'Earnings Before Interest Tax Depreciation Amortization',
        r'\bCEO\b': 'Chief Executive Officer',
        r'\bCFO\b': 'Chief Financial Officer',
        r'\bR&D\b': 'Research and Development'
    }

    for abbrev, full_form in financial_abbrevs.items():
        text = re.sub(abbrev, full_form, text, flags=re.IGNORECASE)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s\$\%\-\.]', ' ', text)

    return text.strip()

def create_risk_indicators(text, financial_metrics=None):
    risk_indicators = {
        'text_signals': [],
        'severity_score': 0,
        'confidence': 0.0,
        'key_factors': []
    }
    critical_signals = [
        'investigation', 'fraud', 'bankruptcy', 'criminal', 'lawsuit',
        'breach', 'scandal', 'arrest', 'violation', 'default'
    ]

    high_signals = [
        'downgrade', 'layoff', 'recall', 'sanction', 'penalty',
        'loss', 'decline', 'warning', 'concern', 'risk'
    ]

    medium_signals = [
        'volatility', 'uncertainty', 'pressure', 'challenge',
        'competition', 'restructuring', 'review'
    ]

    low_signals = [
        'growth', 'profit', 'strong', 'increase', 'success',
        'partnership', 'approval', 'expansion', 'innovation'
    ]

    text_lower = text.lower()

    critical_count = sum(1 for signal in critical_signals if signal in text_lower)
    high_count = sum(1 for signal in high_signals if signal in text_lower)
    medium_count = sum(1 for signal in medium_signals if signal in text_lower)
    low_count = sum(1 for signal in low_signals if signal in text_lower)

    severity = (critical_count * 4) + (high_count * 3) + (medium_count * 2) + (low_count * 1)
    risk_indicators['severity_score'] = severity

    if critical_count > 0:
        risk_indicators['key_factors'].extend([s for s in critical_signals if s in text_lower])
    if high_count > 0:
        risk_indicators['key_factors'].extend([s for s in high_signals if s in text_lower])

    total_signals = critical_count + high_count + medium_count + low_count
    risk_indicators['confidence'] = min(1.0, total_signals / 3.0)  # Max confidence at 3+ signals

    return risk_indicators

news_data['cleaned_headline'] = news_data['headline'].apply(clean_financial_text)
news_data['risk_indicators'] = news_data['cleaned_headline'].apply(create_risk_indicators)

news_data.head()

Unnamed: 0,headline,risk_level,company,sector,business_type,date,source,cleaned_headline,risk_indicators
0,Ford EV sales exceed growth projections,Low,Ford,Automotive,Traditional Auto,2023-09-10,Bloomberg,Ford EV sales exceed growth projections,"{'text_signals': [], 'severity_score': 1, 'con..."
1,Oracle wins major enterprise contract worth $1B,Low,Oracle,Technology,Database Software,2023-07-31,Bloomberg,Oracle wins major enterprise contract worth $1B,"{'text_signals': [], 'severity_score': 0, 'con..."
2,Amazon product safety recall affects millions ...,Critical,Amazon,Retail,E-commerce,2024-12-25,MarketWatch,Amazon product safety recall affects millions ...,"{'text_signals': [], 'severity_score': 3, 'con..."
3,Pfizer strong pipeline drives future growth pr...,Low,Pfizer,Healthcare,Pharmaceuticals,2025-02-12,MarketWatch,Pfizer strong pipeline drives future growth pr...,"{'text_signals': [], 'severity_score': 2, 'con..."
4,Ford breakthrough in battery technology announced,Low,Ford,Automotive,Traditional Auto,2024-03-18,WSJ,Ford breakthrough in battery technology announced,"{'text_signals': [], 'severity_score': 0, 'con..."


In [10]:
news_training_examples = []

for _, row in news_data.iterrows():

    instruction = """
        Analyze the following financial news headline and classify the risk level. Respond with a JSON object containing:
        - "risk_level": one of ["Low", "Medium", "High", "Critical"]
        - "confidence": float between 0.0 and 1.0
        - "reasoning": brief explanation of the classification
        - "key_factors": list of main risk indicators
        Financial News:
    """
    input_text = row['cleaned_headline']
    expected_output = {
        "risk_level": row['risk_level'],
        "confidence": row['risk_indicators']['confidence'],
        "reasoning": f"Classification based on {row['sector']} industry patterns and risk indicators: {', '.join(row['risk_indicators']['key_factors'][:3])}",
        "key_factors": row['risk_indicators']['key_factors'][:3]
    }

    training_example = {
        "instruction": instruction,
        "input": input_text, #cleaned_headline
        "output": json.dumps(expected_output, indent=2),
        "source": "financial_news",
        "industry": row['sector'],
        "risk_level": row['risk_level']
    }

    news_training_examples.append(training_example)

def loan_to_risk_level(loan_status, grade):
    """Convert loan data to risk classification"""
    if loan_status in ['Charged Off', 'Default']:
        return 'Critical'
    elif loan_status in ['Late (31-120 days)', 'Late (16-30 days)'] or grade in ['F', 'G']:
        return 'High'
    elif grade in ['D', 'E'] or loan_status == 'In Grace Period':
        return 'Medium'
    else:
        return 'Low'

def create_loan_narrative(row):
    """Convert loan data row into descriptive text"""
    narrative = f"Loan application for ${row['loan_amnt']:,.0f} over {row['term']} "
    narrative += f"with {row['int_rate']:.1f}% interest rate. "
    narrative += f"Applicant has {row['emp_length']} employment history "
    narrative += f"and ${row['annual_inc']:,.0f} annual income. "
    narrative += f"Credit grade: {row['grade']}. "

    if row['loan_status'] == 'Fully Paid':
        narrative += "Loan was successfully repaid in full."
    elif row['loan_status'] == 'Charged Off':
        narrative += "Loan defaulted and was charged off."
    elif 'Late' in row['loan_status']:
        narrative += f"Current status: {row['loan_status']}."
    else:
        narrative += f"Current status: {row['loan_status']}."

    return narrative

loan_data['risk_level'] = loan_data.apply(
    lambda row: loan_to_risk_level(row['loan_status'], row['grade']), axis=1
)

loan_sample = loan_data.sample(n=min(1000, len(loan_data)), random_state=42)
loan_training_examples = []

for _, row in loan_sample.iterrows():

    # Create loan narrative
    loan_narrative = create_loan_narrative(row)

    instruction = """Analyze the following loan profile and assess the credit risk level. Respond with a JSON object containing:
    - "risk_level": one of ["Low", "Medium", "High", "Critical"]
    - "confidence": float between 0.0 and 1.0
    - "reasoning": explanation of risk assessment
    - "key_factors": list of main risk factors
    Loan Profile:

    """
    grade_risk_map = {'A': 'Low', 'B': 'Low', 'C': 'Medium', 'D': 'Medium', 'E': 'High', 'F': 'High', 'G': 'Critical'}
    expected_grade_risk = grade_risk_map.get(row['grade'], 'Medium')
    confidence = 0.9 if expected_grade_risk == row['risk_level'] else 0.7

    expected_output = {
        "risk_level": row['risk_level'],
        "confidence": confidence,
        "reasoning": f"Risk assessment based on credit grade {row['grade']}, employment history ({row['emp_length']}), and loan outcome ({row['loan_status']})",
        "key_factors": [
            f"credit_grade_{row['grade']}",
            f"loan_status_{row['loan_status'].replace(' ', '_')}",
            f"debt_to_income_ratio"
        ]
    }
    training_example = {
        "instruction": instruction,
        "input": loan_narrative,
        "output": json.dumps(expected_output, indent=2),
        "source": "loan_data",
        "industry": "Financial Services",
        "risk_level": row['risk_level']
    }
    loan_training_examples.append(training_example)


def create_company_risk_profile(row):
    """Create company risk profile from SEC data"""
    profile = f"Company: {row['Security']} ({row['Symbol']}) "
    profile += f"operates in {row['GICS Sector']} sector, "
    profile += f"specifically {row['GICS Sub-Industry']}. "
    profile += f"Founded in {row['Founded']}, "
    profile += f"headquartered in {row['Headquarters Location']}. "
    profile += f"Added to S&P 500 on {row['Date added']}. "

    # Add risk context based on our previous analysis
    try:
        company_age = 2025 - int(row['Founded'])
    except:
        founded_year = re.search(r'\d{4}', row['Founded'])
        if founded_year:
            company_age = 2025 - int(founded_year.group())
        else:
            company_age = 0


    if company_age < 20:
        profile += "Relatively young company with higher growth risk. "
    elif company_age > 100:
        profile += "Well-established company with stable operations. "

    return profile


if 'Overall_Risk' not in sec_data.columns:
    # Recreate risk analysis if not present
    def assign_sector_risk(sector, sub_industry):
        risk_map = {
            'Information Technology': 'Medium',
            'Financials': 'High',
            'Health Care': 'Medium',
            'Energy': 'High',
            'Consumer Discretionary': 'High',
            'Consumer Staples': 'Low',
            'Industrials': 'Medium',
            'Materials': 'High',
            'Communication Services': 'Medium',
            'Utilities': 'Low',
            'Real Estate': 'Medium'
        }
        return risk_map.get(sector, 'Medium')

    sec_data['Overall_Risk'] = sec_data.apply(
        lambda row: assign_sector_risk(row['GICS Sector'], row['GICS Sub-Industry']),
        axis=1
    )

sec_training_examples = []

for _, row in sec_data.iterrows():

    company_profile = create_company_risk_profile(row)

    instruction = """Analyze the following company profile and assess the business risk level based on industry, maturity, and market position. Respond with a JSON object containing:
- "risk_level": one of ["Low", "Medium", "High", "Critical"]
- "confidence": float between 0.0 and 1.0
- "reasoning": explanation of risk assessment
- "key_factors": list of main risk factors

Company Profile: """

    expected_output = {
        "risk_level": row['Overall_Risk'],
        "confidence": 0.8,
        "reasoning": f"Risk assessment based on {row['GICS Sector']} sector characteristics, company maturity (founded {row['Founded']}), and market position (S&P 500 member)",
        "key_factors": [
            f"sector_{row['GICS Sector'].replace(' ', '_').lower()}",
            f"sub_industry_{row['GICS Sub-Industry'].replace(' ', '_').lower()}",
            "sp500_member"
        ]
    }

    training_example = {
        "instruction": instruction,
        "input": company_profile,
        "output": json.dumps(expected_output, indent=2),
        "source": "sec_data",
        "industry": row['GICS Sector'],
        "risk_level": row['Overall_Risk']
    }

    sec_training_examples.append(training_example)

all_training_examples = news_training_examples + loan_training_examples + sec_training_examples

news_training_examples[0], loan_training_examples[0], sec_training_examples[0]

({'instruction': '\n        Analyze the following financial news headline and classify the risk level. Respond with a JSON object containing:\n        - "risk_level": one of ["Low", "Medium", "High", "Critical"]\n        - "confidence": float between 0.0 and 1.0\n        - "reasoning": brief explanation of the classification\n        - "key_factors": list of main risk indicators\n        Financial News: \n    ',
  'input': 'Ford EV sales exceed growth projections',
  'output': '{\n  "risk_level": "Low",\n  "confidence": 0.3333333333333333,\n  "reasoning": "Classification based on Automotive industry patterns and risk indicators: ",\n  "key_factors": []\n}',
  'source': 'financial_news',
  'industry': 'Automotive',
  'risk_level': 'Low'},
 {'instruction': 'Analyze the following loan profile and assess the credit risk level. Respond with a JSON object containing:\n    - "risk_level": one of ["Low", "Medium", "High", "Critical"]  \n    - "confidence": float between 0.0 and 1.0\n    - "rea

In [11]:
training_df = pd.DataFrame(all_training_examples)
train_examples, val_examples = train_test_split(
    all_training_examples,
    test_size=0.2,
    random_state=42,
    stratify=[ex['risk_level'] for ex in all_training_examples]
)
train_risk_dist = Counter([ex['risk_level'] for ex in train_examples])
val_risk_dist = Counter([ex['risk_level'] for ex in val_examples])

def format_for_mistral(example):

    # Mistral instruction format
    formatted = f"""<s>[INST] {example['instruction']}{example['input']} [/INST] {example['output']}</s>"""

    return {
        'text': formatted,
        'instruction': example['instruction'],
        'input': example['input'],
        'output': example['output'],
        'source': example['source'],
        'industry': example['industry'],
        'risk_level': example['risk_level']
    }

# Format training and validation sets
formatted_train = [format_for_mistral(ex) for ex in train_examples]
formatted_val = [format_for_mistral(ex) for ex in val_examples]
train_df = pd.DataFrame(formatted_train)
val_df = pd.DataFrame(formatted_val)

train_df.to_csv('data/train_dataset.csv', index=False)
val_df.to_csv('data/val_dataset.csv', index=False)


In [12]:
display(train_df.head(1))
print(train_df.loc[0, 'text'])
print()
print(train_df.loc[0, 'instruction'])
print()
print(train_df.loc[0, 'input'])
print()
print(train_df.loc[0, 'output'])

Unnamed: 0,text,instruction,input,output,source,industry,risk_level
0,<s>[INST] Analyze the following loan profile a...,Analyze the following loan profile and assess ...,"Loan application for $40,000 over 36 months wi...","{\n ""risk_level"": ""Critical"",\n ""confidence""...",loan_data,Financial Services,Critical


<s>[INST] Analyze the following loan profile and assess the credit risk level. Respond with a JSON object containing:
    - "risk_level": one of ["Low", "Medium", "High", "Critical"]  
    - "confidence": float between 0.0 and 1.0
    - "reasoning": explanation of risk assessment
    - "key_factors": list of main risk factors
    Loan Profile: 

    Loan application for $40,000 over 36 months with 12.6% interest rate. Applicant has 8 years employment history and $179,970 annual income. Credit grade: C. Loan defaulted and was charged off. [/INST] {
  "risk_level": "Critical",
  "confidence": 0.7,
  "reasoning": "Risk assessment based on credit grade C, employment history (8 years), and loan outcome (Charged Off)",
  "key_factors": [
    "credit_grade_C",
    "loan_status_Charged_Off",
    "debt_to_income_ratio"
  ]
}</s>

Analyze the following loan profile and assess the credit risk level. Respond with a JSON object containing:
    - "risk_level": one of ["Low", "Medium", "High", "Criti

In [13]:
def check_gpu_memory():
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        total_memory = torch.cuda.get_device_properties(device).total_memory / 1e9
        allocated_memory = torch.cuda.memory_allocated(device) / 1e9
        free_memory = total_memory - allocated_memory

        print(f"GPU: {torch.cuda.get_device_name(device)}")
        print(f"Total VRAM: {total_memory:.1f} GB")
        print(f"Free VRAM: {free_memory:.1f} GB")

        if free_memory < 10:
            print("Low GPU memory.")

        return True
    else:
        return False

gpu_available = check_gpu_memory()
if gpu_available:
    torch.cuda.empty_cache()
    print("GPU cache cleared")


GPU: NVIDIA A100-SXM4-40GB
Total VRAM: 42.5 GB
Free VRAM: 42.5 GB
GPU cache cleared


In [14]:
from huggingface_hub import login
login(new_session=False)


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [15]:
import transformers
transformers.__version__

'4.52.4'

In [17]:
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
# model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config = bnb_config,)

In [18]:
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_use_double_quant=True,
                                bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name,
                                          trust_remote_code=True,
                                          padding_size="left",
                                          add_eos_token=True,
                                          add_bos_token=True)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             quantization_config = bnb_config,
                                             device_map = "auto",
                                             trust_remote_code=True,
                                             torch_dtype = torch.bfloat16)

if gpu_available:
    model_memory = torch.cuda.memory_allocated() / 1e9
    print(f"🧠 Model memory usage: {model_memory:.1f} GB")

ImportError: Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`

In [17]:
!pip install -U bitsandbytes



In [19]:
model = prepare_model_for_kbit_training(model)

In [20]:
lora_config = LoraConfig(r = 16,
                         target_modules = [
                             "q_proj",
                             "k_proj",
                             "v_proj",
                             "o_proj",
                             "gate_proj",
                             "up_proj",
                             "down_proj",
                         ],
                         lora_dropout=0.1,
                         bias="none",
                         task_type=TaskType.CAUSAL_LM)

model = get_peft_model(model, lora_config)

In [21]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"📊 Trainable params: {trainable_params:,} ({100 * trainable_params / all_param:.2f}%)")
    print(f"📊 All params: {all_param:,}")
print_trainable_parameters(model)

📊 Trainable params: 41,943,040 (1.11%)
📊 All params: 3,794,014,208


In [22]:
def tokenize_function(examples):
    tokenized = tokenizer(examples["text"],
                          truncation=True,
                          padding=False,
                          max_length=2048,
                          return_overflowing_tokens=False)
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)

tokenized_train = train_dataset.map(tokenize_function,
                                    batched=True,
                                    remove_columns=train_dataset.column_names,
                                    desc = "Tokenizing Training Data")
tokenized_val = val_dataset.map(tokenize_function,
                                    batched=True,
                                    remove_columns=train_dataset.column_names,
                                    desc = "Tokenizing Validation Data")

Tokenizing Training Data:   0%|          | 0/1792 [00:00<?, ? examples/s]

Tokenizing Validation Data:   0%|          | 0/449 [00:00<?, ? examples/s]

In [25]:
training_args = TrainingArguments(output_dir="./financial-risk-mistral",
                                  num_train_epochs=3,
                                  per_device_train_batch_size=1,
                                  per_device_eval_batch_size=1,
                                  gradient_accumulation_steps=4,
                                  learning_rate = 2e-4,
                                  weight_decay=0.001,
                                  lr_scheduler_type="cosine",
                                  warmup_ratio=0.1,
                                  dataloader_pin_memory=False,
                                  gradient_checkpointing=True,
                                  fp16=False,
                                  bf16=True,
                                  eval_strategy="steps",
                                  eval_steps=50,
                                  save_strategy="steps",
                                  save_steps=100,
                                  save_total_limit=2,
                                  logging_steps=10,
                                  logging_dir="./logs",
                                  report_to=None,
                                  load_best_model_at_end=True,
                                  metric_for_best_model="eval_loss",
                                  greater_is_better=False,
                                  remove_unused_columns=False,
                                  label_names=["labels"])

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8,
)


def extract_risk_level_from_output(text):

    try:
        if '{' in text and '}' in text:
            json_start = text.find('{')
            json_end = text.rfind('}') + 1
            json_text = text[json_start:json_end]
            parsed = json.loads(json_text)
            return parsed.get('risk_level', 'Unknown')
    except:
        pass

    text_lower = text.lower()
    if 'critical' in text_lower:
        return 'Critical'
    elif 'high' in text_lower:
        return 'High'
    elif 'medium' in text_lower:
        return 'Medium'
    elif 'low' in text_lower:
        return 'Low'

    return 'Unknown'


def compute_metrics(eval_pred):
    """Compute custom metrics for financial risk classification"""
    predictions, labels = eval_pred

    # Decode predictions
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Extract risk levels
    pred_risks = [extract_risk_level_from_output(pred) for pred in decoded_preds]
    true_risks = [extract_risk_level_from_output(label) for label in decoded_labels]

    # Calculate accuracy
    correct = sum(1 for p, t in zip(pred_risks, true_risks) if p == t and p != 'Unknown')
    total_valid = sum(1 for p in pred_risks if p != 'Unknown')

    accuracy = correct / total_valid if total_valid > 0 else 0.0

    # Calculate coverage (% of valid predictions)
    coverage = total_valid / len(pred_risks) if len(pred_risks) > 0 else 0.0

    return {
        'risk_accuracy': accuracy,
        'prediction_coverage': coverage,
        'valid_predictions': total_valid,
        'total_predictions': len(pred_risks)
    }


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [33]:
if gpu_available:
    torch.cuda.empty_cache()

# Record start time
start_time = datetime.now()
print(f"🕐 Training started at: {start_time.strftime('%H:%M:%S')}")

try:
    # Start training
    trainer.train()

    # Record end time
    end_time = datetime.now()
    training_duration = end_time - start_time

    print(f"\n✅ Training completed!")
    print(f"⏱️  Training duration: {training_duration}")
    print(f"🕐 Finished at: {end_time.strftime('%H:%M:%S')}")

except Exception as e:
    print(f"\n❌ Training failed: {str(e)}")
    print("💡 Try reducing batch size or sequence length if out of memory")
    raise


print("\n💾 Saving fine-tuned model...")

# Save the model and tokenizer
model.save_pretrained("./financial-risk-mistral-final")
tokenizer.save_pretrained("./financial-risk-mistral-final")

print("✅ Model saved to: ./financial-risk-mistral-final")


🕐 Training started at: 21:42:56

❌ Training failed: Accelerator.unwrap_model() got an unexpected keyword argument 'keep_torch_compile'
💡 Try reducing batch size or sequence length if out of memory


TypeError: Accelerator.unwrap_model() got an unexpected keyword argument 'keep_torch_compile'

In [31]:
train_results = trainer.state.log_history

# Extract metrics
train_losses = [log['train_loss'] for log in train_results if 'train_loss' in log]
eval_losses = [log['eval_loss'] for log in train_results if 'eval_loss' in log]
eval_accuracies = [log.get('eval_risk_accuracy', 0) for log in train_results if 'eval_risk_accuracy' in log]

print(f"📊 Training Summary:")
print(f"   Final training loss: {train_losses[-1]:.4f}" if train_losses else "   No training loss recorded")
print(f"   Final validation loss: {eval_losses[-1]:.4f}" if eval_losses else "   No validation loss recorded")
print(f"   Final risk accuracy: {eval_accuracies[-1]:.3f}" if eval_accuracies else "   No accuracy recorded")

# Plot training curves if we have data
if train_losses and eval_losses:
    plt.figure(figsize=(12, 4))

    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss', color='blue')
    plt.plot(eval_losses, label='Validation Loss', color='red')
    plt.title('Training and Validation Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Accuracy curve
    plt.subplot(1, 2, 2)
    if eval_accuracies:
        plt.plot(eval_accuracies, label='Risk Classification Accuracy', color='green')
        plt.title('Risk Classification Accuracy')
        plt.xlabel('Evaluation Steps')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("📊 Training curves saved as: training_curves.png")

# =============================================================================
# STEP 12: Quick Model Test
# =============================================================================

print("\n🧪 Testing fine-tuned model...")

def test_model_prediction(text, model, tokenizer):
    """Test the fine-tuned model on a sample input"""

    # Format input like our training data
    instruction = """Analyze the following financial news headline and classify the risk level. Respond with a JSON object containing:
- "risk_level": one of ["Low", "Medium", "High", "Critical"]
- "confidence": float between 0.0 and 1.0
- "reasoning": brief explanation of the classification
- "key_factors": list of main risk indicators

Financial News: """

    full_input = f"<s>[INST] {instruction}{text} [/INST] "

    # Tokenize
    inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=2048)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"].to(model.device),
            max_new_tokens=200,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract just the model's response (after [/INST])
    if "[/INST]" in response:
        model_response = response.split("[/INST]")[-1].strip()
    else:
        model_response = response

    return model_response

# Test with sample headlines
test_cases = [
    "Apple reports record quarterly earnings beating all analyst expectations",
    "Tesla faces massive vehicle recall due to potential fire hazard",
    "JPMorgan hit with $200M fine for anti-money laundering violations",
    "Microsoft announces strategic AI partnership with leading tech companies"
]

print("🎯 Sample Predictions:")
print("-" * 50)

for i, test_text in enumerate(test_cases, 1):
    try:
        prediction = test_model_prediction(test_text, model, tokenizer)
        print(f"\nTest {i}: {test_text}")
        print(f"Prediction: {prediction}")
        print("-" * 50)
    except Exception as e:
        print(f"\nTest {i} failed: {str(e)}")

📊 Training Summary:
   No training loss recorded
   No validation loss recorded
   No accuracy recorded

🧪 Testing fine-tuned model...
🎯 Sample Predictions:
--------------------------------------------------

Test 1 failed: PeftModelForCausalLM.generate() takes 1 positional argument but 2 were given

Test 2 failed: PeftModelForCausalLM.generate() takes 1 positional argument but 2 were given

Test 3 failed: PeftModelForCausalLM.generate() takes 1 positional argument but 2 were given

Test 4 failed: PeftModelForCausalLM.generate() takes 1 positional argument but 2 were given


In [30]:
!pip install -q --upgrade accelerate transformers peft datasets
!pip install -q --upgrade torch torchaudio torchvision

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/411.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m411.1/411.1 kB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m821.2/821.2 MB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m393.1/393.1 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9 MB[0m [31m132.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m79.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.7/897.7 kB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [29]:
import accelerate
accelerate.__version__

'0.26.0'