# JBT Data Analysis

Extracts Judgement Bias Task (JBT) data from K-Limbic Software datafiles. Specifically for the testing phase, where the column identifiers are different and there are extra possible tones. 

If you're running this then I assume you know what you're doing with Python, packages, etc.

Written by Peter Einarsson Nielsen (pe296) and edited by Olivia Stupart (osrps2)


In [1]:
from pathlib import Path
from dataclasses import dataclass
import dateutil
from typing import List, Optional
from enum import Enum
import csv
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.tools.sm_exceptions import ConvergenceWarning
from datetime import datetime, timedelta
import matplotlib.pyplot as plt


class Side(Enum):
    """Lever side, stored as the single-letter code 'L' or 'R'."""
    left = 'L'
    right = 'R'


class DayTable(Enum):
    """Day-table labels; get_run_info builds a member from the run header's 'Day Table' value."""
    # T4 and T4a differ only in the spelling of "Magnitude" ('Magintude' vs 'Magnitude').
    # NOTE(review): presumably both spellings occur in the datafiles - confirm before removing either.
    T4 = 'Training 4 Reward Magintude Training'
    T4a = 'Training 4 Reward Magnitude Training'
    T3 = 'Training 3'
    T2a = 'Training 2    2 KHz - 8 KHz'
    T2b = 'Training 2    8 KHz - 2 KHz'
    TT = 'Testing'


class ColumnIdx(Enum):
    """Positions of the values of interest within one parsed trial row.

    The indices apply to the integer lists built by get_trials, i.e. after
    empty cells have been dropped from the CSV line.
    """
    outcome = 1
    # Item index of the tone that was played:
    #   0: 2kHz (L), 1: 4.5kHz L, 2: 4.5kHz R, 3: 5kHz L, 4: 5kHz R,
    #   5: 5.5kHz L, 6: 5.5kHz R, 7: 8kHz (R)
    tone = 2
    
    s4entry = 11  # timestamp: stimulus presented
    s4exit = 12  # timestamp: lever touched

    s4L1 = 13  # left lever pressed
    s4L2 = 14  # right lever pressed

    s7 = 18  # correct reward - not used in final
    s8 = 19  # correct reward if dispense 2 - not used in final
    s9 = 20  # correct reward 1
    s10 = 21  # correct reward 3

    s12entry = 22  # timestamp: timeout
    s12exit = 23  # timestamp: timeout - always 10s
    s12L1 = 24  # premature left
    s12L2 = 25  # premature right

    s15entry = 36  # timestamp: ITI - always 5s
    s15exit = 37  # timestamp: ITI
    s15L1 = 38  # premature left
    s15L2 = 39  # premature right
    
    s16entry = 43  # timestamp: entry to premature time out
    s16L1 = 45  # premature left
    s16L2 = 46  # premature right
    
    # NOTE: the output file does not state which side was correct for a trial, so
    # incorrect trials are currently not differentiated by side. We still need to find
    # out where the correct side is defined.
    # For magnitude training, 50 trials are left and 50 trials are right every time.

@dataclass
class ExperimentInfo():
    """Header information describing one run of a datafile.

    The field order determines the column order written by Experiment.export_to_csv.
    NOTE(review): get_run_info stores subject_id, box_id, duration and pellet_count
    as the raw header cell values without converting them, so they are presumably
    strings at runtime despite the int annotations - confirm.
    """
    datetime: datetime  # parsed from the header's 'Date' and 'Time' rows
    subject_id: int  # header row 'Subject Id'
    box_id: int  # header row 'Box Index'
    day_table: DayTable  # header row 'Day Table'
    duration: int  # header row 'Duration'
    pellet_count: int  # header row 'Pellet Count'


@dataclass
class TrialResult():
    """Outcome of a single trial; every field except ``tone`` defaults to None."""
    tone: int  # tone item index (see ColumnIdx.tone)
    #is_reversal: bool  # True if 'trial' is a reversal event. Otherwise False.
    choice_correct: Optional[bool] = None  # was their choice correct?
    chosen_side: Optional[Side] = None  # which side was chosen?
    correct_side: Optional[Side] = None  # which side was correct?
    #stuck_choice: Optional[bool] = None  # was their choice the same as for the last trial? (None for first run)
    reward_given: Optional[bool] = None  # were they given a reward?
    #reward_misleading: Optional[bool] = None  # was it a misleading reward?
    latency_choice: Optional[int] = None  # how long did it take to choose a side? (ms)
    #latency_collect: Optional[int] = None  # how long did it take to collect the reward?  (ms, None if no reward given)
    #latency_initiate: Optional[int] = None  # how long did it take to initiate the next trial? (ms, None if reward given)
    premature: Optional[int] = None  # total number of premature responses
    missed: Optional[bool] = None  # True for a missed trial

@dataclass
class ExperimentFindings():
    """Summary statistics for one run, filled in by Experiment.analyse.

    The field order determines the column order written by Experiment.export_to_csv.
    Field suffixes name the tone: _2 (2kHz), _45 (4.5kHz), _5 (5kHz), _55 (5.5kHz), _8 (8kHz).
    The perc_* fields hold fractions (results of a division), not values scaled to 0-100.
    """

    # --- counts ---
    num_trials: int = 0
    #num_reversals: int = 0
    num_correct: int = 0
    #num_misleading_rewards: int = 0
    #num_misleading_loss: int = 0
    num_premature: Optional[int] = None
    num_left: int = 0
    num_right: int = 0
    num_left_correct: Optional[int] = None
    num_right_correct: Optional[int] = None
    num_missed: Optional[int] = None
    num_left_2: Optional[int] = None
    num_right_2: Optional[int] = None
    num_missed_2: Optional[int] = None
    num_left_45: Optional[int] = None
    num_right_45: Optional[int] = None
    num_missed_45: Optional[int] = None
    num_left_5: Optional[int] = None
    num_right_5: Optional[int] = None
    num_missed_5: Optional[int] = None
    num_left_55: Optional[int] = None
    num_right_55: Optional[int] = None
    num_missed_55: Optional[int] = None
    num_left_8: Optional[int] = None
    num_right_8: Optional[int] = None
    num_missed_8: Optional[int] = None
    num_learnt_corr: Optional[int] = None  # correct trials on tones 0 and 7
    num_amb_corr: Optional[int] = None  # correct trials on tones 1-6
    #num_trials_to_first_reversal: Optional[int] = None  # TODO
    #mean_perseverative_responses: Optional[float] = None  # TODO

    # --- fractions ---
    perc_correct: float = 0.0
    perc_left_correct: Optional[float] = None
    perc_right_correct: Optional[float] = None
    perc_learnt_correct: Optional[float] = None
    perc_amb_correct: Optional[float] = None
    perc_left_45: Optional[float] = None
    perc_left_5: Optional[float] = None
    perc_left_55: Optional[float] = None
    perc_left_2: Optional[float] = None
    perc_left_8: Optional[float] = None
    #perc_misleading_rewards: float = 0.0
    #perc_misleading_loss: float = 0.0
    perc_premature: Optional[float] = None

    # --- mean choice latencies (same unit as TrialResult.latency_choice) ---
    mean_latency_choice: float = 0.0
    mean_latency_choice_45: float = 0.0
    mean_latency_choice_5: float = 0.0
    mean_latency_choice_55: float = 0.0
    mean_latency_choice_learnt: float = 0.0    
    #mean_latency_collect: float = 0.0
    #mean_latency_initiate: float = 0.0

    

@dataclass
class Experiment():
    """One run: its header info, per-trial results and the summary computed from them.

    ``analyse`` fills ``findings`` in place from ``results``; ``export_to_csv``
    appends ``info`` and ``findings`` as a single row of a CSV file.
    """
    info: ExperimentInfo
    results: List[TrialResult]
    findings: ExperimentFindings

    def analyse(self):
        """Populate ``self.findings`` from ``self.results``.

        Tone item indices are grouped by frequency: 0 -> 2kHz, 1/2 -> 4.5kHz,
        3/4 -> 5kHz, 5/6 -> 5.5kHz, 7 -> 8kHz. Tones 0 and 7 are the learnt
        tones; tones 1-6 are the ambiguous ones.

        Raises:
            ZeroDivisionError: if there are no trials at all, no learnt-tone or
                ambiguous-tone trials, or no left/right responses for a tone
                group whose mean latency is computed.
        """
        trials = self.results
        findings = self.findings

        tone_2, tone_45, tone_5, tone_55, tone_8 = (0,), (1, 2), (3, 4), (5, 6), (7,)
        learnt_tones = (0, 7)
        ambiguous_tones = (1, 2, 3, 4, 5, 6)

        def count(tones=None, *, side=None, correct=False, missed=False):
            """Count the trials that satisfy every criterion that was supplied."""
            return sum(
                1 for trial in trials
                if (tones is None or trial.tone in tones)
                and (side is None or trial.chosen_side == side)
                and (not correct or trial.choice_correct)
                and (not missed or trial.missed)
            )

        def latency_sum(tones=None):
            """Sum the choice latencies (falsy latencies are skipped) for the given tones."""
            return sum(
                trial.latency_choice for trial in trials
                if trial.latency_choice and (tones is None or trial.tone in tones)
            )

        def share_left(n_left, n_right):
            """Fraction of responses made on the left, or None when there were no responses."""
            responded = n_left + n_right
            # Guard on the denominator so "only right responses" yields 0.0 rather than None.
            return n_left / responded if responded > 0 else None

        # --- counts ---
        findings.num_trials = len(trials)
        findings.num_left = count(side=Side.left)
        findings.num_right = count(side=Side.right)
        findings.num_correct = count(correct=True)
        findings.num_left_correct = count(side=Side.left, correct=True)
        findings.num_right_correct = count(side=Side.right, correct=True)

        findings.num_left_2 = count(tone_2, side=Side.left)
        findings.num_left_45 = count(tone_45, side=Side.left)
        findings.num_left_5 = count(tone_5, side=Side.left)
        findings.num_left_55 = count(tone_55, side=Side.left)
        findings.num_left_8 = count(tone_8, side=Side.left)

        findings.num_right_2 = count(tone_2, side=Side.right)
        findings.num_right_45 = count(tone_45, side=Side.right)
        findings.num_right_5 = count(tone_5, side=Side.right)
        findings.num_right_55 = count(tone_55, side=Side.right)
        findings.num_right_8 = count(tone_8, side=Side.right)

        findings.num_missed_2 = count(tone_2, missed=True)
        findings.num_missed_45 = count(tone_45, missed=True)
        findings.num_missed_5 = count(tone_5, missed=True)
        findings.num_missed_55 = count(tone_55, missed=True)
        findings.num_missed_8 = count(tone_8, missed=True)

        findings.num_learnt_corr = count(learnt_tones, correct=True)
        findings.num_amb_corr = count(ambiguous_tones, correct=True)

        # --- fractions ---
        findings.perc_correct = findings.num_correct / findings.num_trials
        findings.perc_learnt_correct = findings.num_learnt_corr / count(learnt_tones)
        findings.perc_amb_correct = findings.num_amb_corr / count(ambiguous_tones)
        # NOTE(review): 70 is hard-coded; presumably the number of trials per side in a
        # full 140-trial session - confirm, and note it is wrong for shorter runs.
        findings.perc_left_correct = (findings.num_left_correct / 70) if findings.num_left > 0 else None
        findings.perc_right_correct = (findings.num_right_correct / 70) if findings.num_right > 0 else None
        findings.perc_left_45 = share_left(findings.num_left_45, findings.num_right_45)
        findings.perc_left_5 = share_left(findings.num_left_5, findings.num_right_5)
        findings.perc_left_55 = share_left(findings.num_left_55, findings.num_right_55)
        findings.perc_left_2 = share_left(findings.num_left_2, findings.num_right_2)
        findings.perc_left_8 = share_left(findings.num_left_8, findings.num_right_8)

        # --- mean choice latencies ---
        # The overall mean is taken over all trials (including missed ones); the
        # per-tone means are taken over the trials with a left/right response.
        findings.mean_latency_choice = latency_sum() / findings.num_trials
        findings.mean_latency_choice_45 = latency_sum(tone_45) / (findings.num_left_45 + findings.num_right_45)
        findings.mean_latency_choice_5 = latency_sum(tone_5) / (findings.num_left_5 + findings.num_right_5)
        findings.mean_latency_choice_55 = latency_sum(tone_55) / (findings.num_left_55 + findings.num_right_55)
        # Learnt tones are 0 and 7, matching the denominator (2kHz + 8kHz responses).
        findings.mean_latency_choice_learnt = latency_sum(learnt_tones) / (
            findings.num_left_2 + findings.num_right_2
            + findings.num_left_8 + findings.num_right_8
        )

        # --- premature and missed responses ---
        findings.num_premature = sum(trial.premature for trial in trials if trial.premature)
        findings.perc_premature = findings.num_premature / (findings.num_trials + findings.num_premature)
        findings.num_missed = count(missed=True)

    def export_to_csv(self, file: Path):
        '''
        Output info and findings to csv file.
        If file exists: append to file.
        If file !exists: create file, write header, then write info and findings.

        DayTable values are written by member name (e.g. 'TT'); every other
        value is written via its string formatting.
        '''
        needs_header = not file.is_file()

        row = []
        for section in (self.info, self.findings):
            for value in vars(section).values():
                row.append(value.name if isinstance(value, DayTable) else f'{value}')

        # newline='' stops the csv module's own line endings being translated
        # again, which would otherwise produce blank rows on Windows.
        with open(file, 'a', newline='') as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow([*vars(self.info), *vars(self.findings)])
            writer.writerow(row)
            





In [2]:
def get_runs(datafile):
    '''Identify all STARTDATA, ENDDATA chunks in a datafile.

    Returns a list of runs. Each run is the list of CSV rows from a row
    containing 'STARTDATA' up to, but not including, the matching row
    containing 'ENDDATA'. Start and end markers are paired in order of
    appearance; an unpaired trailing marker is ignored.
    '''
    # newline='' lets the csv module handle line endings itself, as its
    # documentation requires for correct parsing of quoted fields.
    with open(datafile, 'r', newline='') as ro:
        all_data = list(csv.reader(ro))

    start_indices = [i for i, row in enumerate(all_data) if 'STARTDATA' in row]
    end_indices = [i for i, row in enumerate(all_data) if 'ENDDATA' in row]

    return [all_data[i:j] for i, j in zip(start_indices, end_indices)]


In [3]:
def check_header_row(run: List, header: str):
    """Return the second cell of the first non-empty row whose first cell equals ``header``.

    Raises IndexError when no row matches.
    """
    values = []
    for row in run:
        if row and row[0] == header:
            values.append(row[1])
    return values[0]


def get_main_row_idx(run, search_term):
    """Return the index of the first row of ``run`` that contains ``search_term``.

    Raises IndexError when no row contains it.
    """
    for idx, row in enumerate(run):
        if search_term in row:
            return idx
    raise IndexError('list index out of range')


def get_run_info(run):
    """Build an ExperimentInfo from the header section of a run.

    The header is taken as the rows after the first row of ``run`` up to
    (not including) the row containing 'AC Comment'. All values except the
    date/time and the day table are stored exactly as read from the header.

    Raises:
        IndexError: if 'AC Comment' or a required header row is missing.
        ValueError: if the 'Day Table' value is not a known DayTable label.
    """
    # `import dateutil` alone does not guarantee that the `parser` submodule is
    # loaded, so import it explicitly before use.
    import dateutil.parser

    header = run[1:get_main_row_idx(run, 'AC Comment')]

    def field(name):
        return check_header_row(header, name)

    # Extract pertinent header information
    return ExperimentInfo(
        datetime=dateutil.parser.parse(f"{field('Date')} {field('Time')}"),
        subject_id=field('Subject Id'),
        box_id=field('Box Index'),
        day_table=DayTable(field('Day Table')),
        duration=field('Duration'),
        pellet_count=field('Pellet Count'),
    )



In [4]:
# Extract pertinent trial information

def get_trials(run):
    """Return the trial rows of a run as lists of integers.

    The trials section runs from three rows after the row containing
    'Stage (3)' up to one row before the row containing 'ACTIVITYLOG'.
    Empty rows and rows containing 'Ref' are skipped, and empty cells are
    dropped before the remaining cells are converted to int.
    """
    first = get_main_row_idx(run, 'Stage (3)') + 3
    last = get_main_row_idx(run, 'ACTIVITYLOG') - 1

    parsed = [
        [int(cell) for cell in line if cell]
        for line in run[first:last]
        if line and 'Ref' not in line
    ]

    # Rows whose second value is one of these codes are not real trials:
    #   1000 - 'test-is-ready' marker
    #   128  - run finished prematurely
    #   150  - run finished (i.e. full 140 trials in run)
    non_trial_codes = (1000, 128, 150)
    return [row for row in parsed if row[1] not in non_trial_codes]



def get_trial_info(trial: List, run_info: ExperimentInfo):
    """Convert one parsed trial row into a TrialResult.

    Args:
        trial: list of integers for one trial, indexed by ColumnIdx values.
        run_info: header info of the run (currently unused; kept for callers).

    Returns:
        A TrialResult. Missed trials (outcome code 2 or 150, or neither lever
        pressed) carry only ``tone`` and ``missed=True``. If the lever columns
        hold any other unexpected combination, ``chosen_side`` and
        ``correct_side`` are None and the remaining fields are still filled in.
    """
    # There are no reversals in JBT, so every row is treated as a trial.
    outcome = trial[ColumnIdx.outcome.value]
    tone = trial[ColumnIdx.tone.value]

    if outcome in (2, 150):
        return TrialResult(tone=tone, missed=True)

    left_pressed = trial[ColumnIdx.s4L1.value]
    right_pressed = trial[ColumnIdx.s4L2.value]

    # Neither lever pressed: report it as a missed trial instead of falling
    # through with no chosen side.
    if left_pressed == 0 and right_pressed == 0:
        return TrialResult(tone=tone, missed=True)

    if left_pressed == 1 and right_pressed == 0:
        chosen_side = Side.left
    elif left_pressed == 0 and right_pressed == 1:
        chosen_side = Side.right
    else:
        # Unexpected lever combination: the side cannot be determined.
        chosen_side = None

    choice_correct = outcome == 0

    # The datafile does not state the correct side, so derive it from the
    # chosen side and whether that choice was correct.
    if chosen_side is None:
        correct_side = None
    elif choice_correct:
        correct_side = chosen_side
    elif chosen_side == Side.left:
        correct_side = Side.right
    else:
        correct_side = Side.left

    reward_given = (
        trial[ColumnIdx.s7.value] == 1
        or trial[ColumnIdx.s8.value] == 1
        or trial[ColumnIdx.s9.value] == 1
        or trial[ColumnIdx.s10.value] == 3
    )

    latency_choice = trial[ColumnIdx.s4exit.value] - trial[ColumnIdx.s4entry.value]

    # Total premature presses across the timeout, ITI and premature-timeout stages.
    premature_columns = (
        ColumnIdx.s12L1, ColumnIdx.s12L2,
        ColumnIdx.s15L1, ColumnIdx.s15L2,
        ColumnIdx.s16L1, ColumnIdx.s16L2,
    )
    premature = sum(trial[column.value] for column in premature_columns)

    return TrialResult(
        tone=tone,
        choice_correct=choice_correct,
        chosen_side=chosen_side,
        correct_side=correct_side,
        reward_given=reward_given,
        latency_choice=latency_choice,
        premature=premature,
    )


In [5]:
# Put it all together

def get_experiments(datafile) -> List[Experiment]:
    """Parse a datafile and return one un-analysed Experiment per usable run.

    Runs whose day table is a training table are skipped, as are runs that
    yield no trials. A message is printed for every skipped run.
    """
    runs = get_runs(datafile)
    print(f'NUMBER OF RUNS IN {datafile}: {len(runs)}')

    touch_training_tables = (DayTable.T2a, DayTable.T2b)
    other_training_tables = (DayTable.T3, DayTable.T4, DayTable.T4a)

    experiments = []
    for run in runs:
        run_info = get_run_info(run)  # ExperimentInfo

        if run_info.day_table in touch_training_tables:
            print('Run has daytable Touch Training: ignoring.')
            continue
        if run_info.day_table in other_training_tables:
            print('Run has daytable training: ignoring.')
            continue

        trial_results = [get_trial_info(trial, run_info) for trial in get_trials(run)]

        # A run with no trials means the participant did not complete the
        # session, so there is nothing to analyse.
        if not trial_results:
            print(f'Ignoring run. No trials in run: {run_info}')
            continue

        experiments.append(
            Experiment(
                info=run_info,
                results=trial_results,
                findings=ExperimentFindings(),
            )
        )

    return experiments


# Running the script
Each datafile contains multiple experiments.

get_experiments(df) parses a datafile and returns a list of Experiment objects.

Each experiment can be analysed by running exp.analyse() where exp is an experiment object.

In [6]:
# Parse every selected datafile, analyse each experiment and append the
# findings to FINDINGS_FILE (one row per experiment).

#Large FINDINGS_FILE = Path('./JBT/findings_JBT.csv')
FINDINGS_FILE = Path('./Findings_B4.csv')
DATAFOLDER = Path('./B4-018')

# is_file must be called: the bare bound method `p.is_file` is always truthy,
# so it would never filter out directories.
datafiles = [
    p for p in DATAFOLDER.iterdir()
    if p.is_file() and p.suffix == '.csv' and 'Nov' in p.name
]


# Alternative configuration (several data folders), kept for reference:
#DATAFOLDERS = [
 #    Path('./MS_Cohort_1/Animal_Data/Male/Corticosterone/'),
 #   Path('./MS_Cohort_1/Animal_Data/Female/Corticosterone/'),
# ]

#datafiles1 = ([
#     p for p in DATAFOLDERS[0].iterdir() if p.is_file()
#     and p.suffix == '.csv'# and 'Combined' not in p.name
# ])

#datafiles2 = ([
#     p for p in DATAFOLDERS[1].iterdir() if p.is_file()
#     and p.suffix == '.csv'# and 'Combined' not in p.name
# ])

#datafiles = datafiles_F + datafiles_M


print(len(datafiles))

experiments = [exp for df in datafiles for exp in get_experiments(df)]

print(f'TOTAL NUMBER OF EXPERIMENTS: {len(experiments)}')

for exp in experiments:
    try:
        exp.analyse()
    except ZeroDivisionError as err:
        # Chain the original error so the traceback shows which division failed.
        raise Exception('Division by zero', exp.info) from err

    exp.export_to_csv(FINDINGS_FILE)


4
NUMBER OF RUNS IN B4-018\01-Nov-2023_003.csv: 8
NUMBER OF RUNS IN B4-018\02-Nov-2023_001.csv: 8
NUMBER OF RUNS IN B4-018\03-Nov-2023_001.csv: 8
NUMBER OF RUNS IN B4-018\06-Nov-2023_001.csv: 8
TOTAL NUMBER OF EXPERIMENTS: 32


In [19]:
# Number each subject's sessions in chronological order and save the result.
# NOTE(review): the original note said "label the baseline sessions as 123 and the
# stress sessions as 123", but the code below numbers sessions 1-6 per subject - confirm
# which labelling is wanted.
# The input file is overwritten in place at the end of this cell.
df = pd.read_csv('Test_findings_females.csv')

# Convert the 'datetime' column to datetime format if not already
df['datetime'] = pd.to_datetime(df['datetime'])

# Sort DataFrame by 'subject_id' and 'datetime'
df = df.sort_values(by=['subject_id', 'datetime'])

# Generate session labels: 1 for a subject's earliest row, 2 for the next, and so on
df['session_label'] = df.groupby('subject_id').cumcount() + 1

# Sessions beyond the sixth for a subject get no label
df['session_label'] = df['session_label'].apply(lambda x: x if x <= 6 else None)

df.to_csv('Test_findings_females.csv',index=False)

In [9]:
# Display the DataFrame currently bound to df.
print(df)

               datetime  subject_id  box_id day_table  duration  pellet_count  \
0   2024-05-08 08:49:20           1       1        TT    213359           114   
40  2024-05-09 09:00:19           1       1        TT    175247           112   
80  2024-05-10 08:39:44           1       1        TT    250007            83   
140 2024-05-16 12:28:28           1       1        TT    289148            89   
168 2024-05-21 11:40:55           1       1        TT    281151            99   
..                  ...         ...     ...       ...       ...           ...   
79  2024-05-09 12:21:45          40       8        TT    186985           117   
119 2024-05-10 11:53:15          40       8        TT    184010           118   
139 2024-05-16 11:41:19          40       8        TT    184188           117   
167 2024-05-21 10:52:31          40       8        TT    228636            99   
207 2024-05-24 10:40:22          40       8        TT    177382           113   

     num_trials  num_correc

In [12]:
%matplotlib inline

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.tools.sm_exceptions import ConvergenceWarning
from datetime import datetime, timedelta
import matplotlib.pyplot as plt

In [13]:
# Load the findings table and normalise missing values to NaN.
df = pd.read_csv(Path('./JBT/testing_findings_drugs1.csv'))
df = df.replace({None: np.nan})

# Define the subject_id values and their corresponding drug_group labels
subject_drug_map = {
    1: "A", 2: "B", 3:"C", 4:"E", 5:"B", 6: "A", 7: "D", 8:"C"
}

# Add the new column "drug_group" based on subject_id values
# (subject_ids missing from the map get NaN).
df["drug_group"] = df["subject_id"].map(subject_drug_map)

In [16]:
# Add the drug each subject received on each dosing day, then save the table.
# NOTE(review): the format string expects dates like '2023/11/10 09:30'; confirm it
# matches how 'datetime' is stored in the findings file.
df['datetime'] = pd.to_datetime(df['datetime'], format='%Y/%m/%d %H:%M')

# dosing date (ISO format) -> drug_group -> drug given on that date
dose_mapping = {
    "2023-11-10": {"A": "CDP", "B": "Cit", "C": "Veh", "D": "Prop", "E":"Aten"},
    "2023-11-14": {"A": "Cit", "B": "CDP", "C": "Prop", "D": "Aten", "E":"Veh"},
    "2023-11-17": {"A": "Veh", "B": "Aten", "C": "CDP", "D": "Cit", "E":"Prop"},
    "2023-11-21": {"A": "Prop", "B": "Veh", "C": "Aten", "D": "CDP", "E":"Cit"},
    "2023-11-24": {"A": "Aten", "B": "Prop", "C": "Cit", "D": "Veh", "E":"CDP"}
}
# Rows on a non-dosing date, or with an unknown drug_group, get None.
df["drug_dose"] = df.apply(lambda row: dose_mapping.get(str(row["datetime"].date()), {}).get(row["drug_group"]), axis=1)
#will test tomorrow 

# index=False: this file is read back by the cell that adds drug_group, so writing the
# index would add another 'Unnamed: 0' column on every run.
df.to_csv(Path('./JBT/testing_findings_drugs1.csv'), index=False)