In [None]:
from pyspark.sql import SparkSession
spark = (
    SparkSession.builder
    .appName("Analysis")
    .master("local[12]")  # local mode with 12 worker threads
    .getOrCreate()
)

In [None]:
# Raw loan portfolio, one row per loan.
df = spark.read.format("parquet").load("../test_data.parquet")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import matplotlib as mpl
from pyspark.sql import functions as fn
from pyspark.sql import Row
from pyspark.sql.functions import col, lit
from datetime import datetime, timedelta
# Projection start date: the first projected period is dated on this day.
ASSUMED_DATE_TODAY = datetime(2023, 1, 1)

In [None]:
# Split the portfolio by repayment profile; each profile gets its own payment rule below.
annuities = df.where(df.Payment_Type == "Annuity")
linears = df.where(df.Payment_Type == "Linear")
bullets = df.where(df.Payment_Type == "Bullet")

In [None]:
# Linear loans repay the notional in equal monthly instalments. `Term` is in
# years (end_date is Start_Date + Term * 12 months, and the annuity formula uses
# Term * 12 periods), so the monthly instalment is Notional / (Term * 12).
# Dividing by Term alone would repay the whole loan in Term months.
linearpayments = linears.rdd.map(
    lambda x: (x['Id'], x['Notional'] / (x['Term'] * 12))
).toDF(["f_Id", "monthly_payment"])

In [None]:
# Bullet loans repay the whole notional at maturity, so they have no scheduled monthly payment.
bulletpayments = (
    bullets.rdd
    .map(lambda loan: (loan["Id"], 0.0))
    .toDF(["f_Id", "monthly_payment"])
)

In [None]:
def calc_annuity_payment(notional, interest, term):
    """Return the level monthly payment that amortises `notional` over `term` years.

    Args:
        notional: amount to amortise.
        interest: annual interest rate in percent (e.g. 3.5 for 3.5%).
        term: term in years; payments are monthly, so there are term * 12 of them.

    Returns:
        The constant monthly payment (interest plus principal).
    """
    monthsTotal = term * 12
    r = interest / 12 / 100
    if r == 0:
        # The annuity formula divides by (1 + r)**n - 1, which is 0 for a zero
        # rate; the payment then reduces to straight-line repayment.
        return notional / monthsTotal
    growth = pow(1 + r, monthsTotal)
    return (notional * r * growth) / (growth - 1)

# Level monthly payment for each annuity loan, keyed by loan Id.
annuitypayments = (
    annuities.rdd
    .map(lambda loan: (
        loan["Id"],
        calc_annuity_payment(loan["Notional"], loan["Interest_Rate"], loan["Term"]),
    ))
    .toDF(["f_Id", "monthly_payment"])
)

In [None]:
#annuities.show()

In [None]:
def _attach_payments(loans, payments):
    """Inner-join a loan frame with its (f_Id, monthly_payment) frame on Id == f_Id."""
    return loans.join(payments, loans['Id'] == payments['f_Id'], "inner")


# Build new frames instead of rebinding `annuities`/`linears`/`bullets`:
# rebinding made this cell non-idempotent (a second run joined the payments
# again and produced duplicate f_Id / monthly_payment columns).
annuities_with_payment = _attach_payments(annuities, annuitypayments)
linears_with_payment = _attach_payments(linears, linearpayments)
bullets_with_payment = _attach_payments(bullets, bulletpayments)
# Union by column name rather than by position, so a column-order difference
# between the three frames cannot silently misalign values.
df_full = (
    annuities_with_payment
    .unionByName(linears_with_payment)
    .unionByName(bullets_with_payment)
)

In [None]:
from pyspark.sql.types import StructType,StructField, StringType, DecimalType, IntegerType, DateType, FloatType
# Maturity date = Start_Date + Term years (add_months takes months, hence * 12).
df_full = df_full.withColumn(
    "end_date",
    fn.add_months(col("Start_Date"), col("Term") * 12).cast(DateType()),
)

In [None]:
#create schema for the results table
from pyspark.sql.types import StructType,StructField, StringType, DecimalType, IntegerType, DateType, FloatType
# Schema of the projection result: one row per loan per projected month.
# Every column is declared non-nullable.
# NOTE(review): FloatType is 32-bit (~7 significant digits) and can lose cents
# on large notionals; DoubleType would be safer but changes the written
# Parquet schema, so it is kept here.
schema = StructType([
    StructField(name, dataType, False)
    for name, dataType in [
        ('Id', StringType()),
        ('Interest_Rate', FloatType()),
        ('Reset_Frequency', IntegerType()),
        ('Remaining_Notional', FloatType()),
        ('Risk_Indicator', IntegerType()),
        ('Next_Reset_Date', DateType()),
        ('Date_Of_Payment', DateType()),
        ('monthly_payment', FloatType()),
        ('Repayment_Payment', FloatType()),
        ('Interest_Payment', FloatType()),
        ('Writeoff', FloatType()),
    ]
])

In [None]:
# Empty DataFrame carrying the projection schema; printing it shows the output layout.
# NOTE(review): payment_projection is not referenced by any later cell in this notebook.
payment_projection = spark.createDataFrame([], schema=schema)
payment_projection.printSchema()

In [None]:
import random

# Monthly risk-migration table, indexed by current risk category (0-4).
# Each entry is (lower_threshold, upper_threshold) for a uniform draw in
# migrate_risk_category: draw < lower moves one category down, draw > upper
# moves one category up, anything in between keeps the category.
riskMigration = dict(enumerate([
	(0.0, 1 - 0.001),   # category 0
	(0.1, 1 - 0.01),    # category 1
	(0.05, 1 - 0.01),   # category 2
	(0.05, 1 - 0.05),   # category 3
	(0.2, 1 - 0.1),     # category 4
]))

# Spread in percentage points added to the rate for a fixed-rate duration in years.
# calc_one_step looks this up by the loan Term and by the chosen reset frequency,
# so a missing key raises KeyError there.
additionalInterestRatePerDuration = {
	30: 2.2,
	25: 1.9,
	20: 1.5,
	15: 1.0,
	10: 0.5,
	9: 0.4,
	7: 0.1,
	5: 0.0,
}

# Spread in percentage points per risk category 0-5 (index = category).
additionalInterestRatePerRiskCategory = dict(enumerate([0.0, 0.3, 1.1, 1.9, 3.5, 9.9]))

def migrate_risk_category(old_category: int):
	draw = random.random()
	probabilities = riskMigration[old_category]
	if draw < probabilities[0]:
		return old_category - 1
	elif draw > probabilities[1]:
		return old_category + 1
	else:
		return old_category

def _shift_years(day, years):
	"""Return `day` moved `years` calendar years later, as a datetime at midnight.

	29 February falls back to 28 February when the target year is not a leap
	year (a plain datetime(...) call raises ValueError there).
	"""
	try:
		return datetime(year=day.year + years, month=day.month, day=day.day)
	except ValueError:
		return datetime(year=day.year + years, month=day.month, day=28)


def _is_final_period(curr_date, end_date):
	"""Return True if curr_date is the last month projected before end_date.

	Projection periods are stepped month by month while strictly before
	end_date, so the last one is the month whose successor month starts on or
	after end_date. For an end_date on the 1st, that is the month before
	end_date's month. Comparing (year, month, day) tuples also avoids comparing
	a datetime with a date, which Python rejects.
	"""
	if curr_date.month == 12:
		next_month_start = (curr_date.year + 1, 1, 1)
	else:
		next_month_start = (curr_date.year, curr_date.month + 1, 1)
	return next_month_start >= (end_date.year, end_date.month, end_date.day)


def calc_one_step(original_row: Row, curr_date: datetime, T_Minus_one: Row):
	"""Project one monthly period of a loan.

	Args:
		original_row: the loan's row from df_full. Reads Id, Interest_Rate,
			Term, Risk_Indicator, Start_Date, end_date and Payment_Type; for the
			first period it also supplies Reset_Frequency, Next_Reset_Date,
			monthly_payment and Remaining_Notional.
		curr_date: date of the period being projected.
		T_Minus_one: Row produced for the previous period, or None for the
			first period (original_row is then the starting state).

	Returns:
		Row with the projection-schema fields for this period.

	Raises:
		ValueError: if Payment_Type is not 'Annuity', 'Linear' or 'Bullet'
			(only checked when the loan did not just default).
	"""
	if T_Minus_one is None:
		T_Minus_one = original_row
	newRisk = migrate_risk_category(T_Minus_one['Risk_Indicator'])
	resetFrequency = T_Minus_one['Reset_Frequency']
	interestRate = T_Minus_one['Interest_Rate']
	resetDate = T_Minus_one['Next_Reset_Date']
	monthlyPayment = T_Minus_one['monthly_payment']
	previousNotional = T_Minus_one['Remaining_Notional']
	paymentType = original_row['Payment_Type']

	# Reset in the month of the next reset date. Compare year and month only:
	# projection dates are first-of-month, so also requiring the day to match
	# would miss resets not dated on the 1st, and ignoring the year would
	# reset on every anniversary.
	if (curr_date.year, curr_date.month) == (resetDate.year, resetDate.month):
		impliedBaseRate = (
			original_row["Interest_Rate"]
			- additionalInterestRatePerDuration[original_row['Term']]
			- additionalInterestRatePerRiskCategory[original_row['Risk_Indicator']]
		)
		# NOTE(review): this is the loan's full term in years, not the years
		# remaining from curr_date, despite the name -- confirm intended.
		yearsLeft = original_row['end_date'].year - original_row['Start_Date'].year
		reset_options = [30, 25, 20, 15, 10, 9, 7, 5]
		# The new reset frequency is implied to be the new duration.
		resetFrequency = min([x for x in reset_options if x >= yearsLeft])
		resetDate = _shift_years(resetDate, yearsLeft)
		interestRate = (
			impliedBaseRate
			+ additionalInterestRatePerDuration[resetFrequency]
			+ additionalInterestRatePerRiskCategory[newRisk]
		)
		if paymentType == 'Annuity':
			# Re-amortise the outstanding notional at the new rate over
			# resetFrequency years.
			monthlyPayment = calc_annuity_payment(previousNotional, interestRate, resetFrequency)

	if newRisk == 5:
		# Defaulted: no cash flows, the whole outstanding notional is written off.
		interest = 0.0
		repayment = 0.0
		writeOff = previousNotional
	else:
		interest = (interestRate / 100 / 12) * previousNotional
		writeOff = 0.0
		if paymentType == 'Bullet':
			# The full notional is repaid in the last projected month.
			if _is_final_period(curr_date, original_row['end_date']):
				repayment = previousNotional
			else:
				repayment = 0.0
		elif paymentType == 'Linear':
			repayment = min(monthlyPayment, previousNotional)
		elif paymentType == 'Annuity':
			# Use the current (possibly re-amortised) payment, not the original one.
			repayment = min(monthlyPayment - interest, previousNotional)
		else:
			raise ValueError(f"Unknown Payment_Type {paymentType!r} for loan {original_row['Id']!r}")

	return Row(
		Id=original_row['Id'],
		Interest_Rate=interestRate,
		Reset_Frequency=resetFrequency,
		Remaining_Notional=previousNotional - repayment - writeOff,
		Risk_Indicator=newRisk,
		Next_Reset_Date=resetDate,
		Date_Of_Payment=curr_date,
		monthly_payment=monthlyPayment,
		Repayment_Payment=repayment,
		Interest_Payment=interest,
		Writeoff=writeOff,
	)


def calc_all_periods_for_row(row: Row):
	"""Project one loan month by month from ASSUMED_DATE_TODAY up to its end_date.

	The first period is dated ASSUMED_DATE_TODAY and every later period is dated
	on the first day of the following month. Periods are generated while their
	date is strictly before end_date. Each period is computed from the previous
	one by calc_one_step. Projection stops early after a period whose
	Risk_Indicator is 5.

	Args:
		row: the loan's row from df_full (reads end_date; passed to calc_one_step).

	Returns:
		List of Rows, one per projected month; empty if end_date is not after
		ASSUMED_DATE_TODAY.
	"""
	endDate = row["end_date"]
	# Convert to datetime so it can be compared with the datetime curr_date.
	endDate = datetime(endDate.year, endDate.month, endDate.day)
	curr_date = ASSUMED_DATE_TODAY
	listresults = []
	previous_period = None
	while curr_date < endDate:
		current_period = calc_one_step(row, curr_date, previous_period)
		listresults.append(current_period)
		previous_period = current_period
		if current_period['Risk_Indicator'] == 5:
			break
		# Step to the first day of the next month. Computed from year/month
		# directly: "add 31 days, rewind to day 1" skips a month when the
		# current day is late in the month (30 Jan + 31 days = 2 Mar).
		if curr_date.month == 12:
			curr_date = curr_date.replace(year=curr_date.year + 1, month=1, day=1)
		else:
			curr_date = curr_date.replace(month=curr_date.month + 1, day=1)
	return listresults
        

In [None]:
# Run the projection: flatMap expands each loan into one Row per projected month.
# The risk migration is random, so every Spark action on an uncached result
# re-runs the simulation with fresh draws, and count/show/write would then
# disagree with each other. Caching makes later actions reuse one realisation.
# NOTE(review): evicted cached partitions are recomputed with new draws; use
# .checkpoint() or write-then-read if a hard guarantee is needed.
copilot = df_full.rdd.flatMap(calc_all_periods_for_row).toDF(schema).cache()

In [None]:
# Number of projected (loan, month) rows in the result.
copilot.count()

In [None]:
# Preview the first 20 projected rows (long cell values truncated).
copilot.show(n=20, truncate=True)

In [None]:
# Persist the projection as snappy-compressed Parquet, replacing any previous output.
copilot.write.mode('overwrite').parquet('./outcome.parquet', compression='snappy')

In [None]:
# Stop the session's underlying SparkContext; `spark` cannot be used afterwards.
spark.stop()