In [None]:
import logging
import random

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import f1_score as f1_score_func
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from tqdm import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer, BertForSequenceClassification

def custom_train_test_split(X, y, test_size=0.2, random_state=None):
    """Split ``X``/``y`` into train and test parts, stratifying where possible.

    Classes with fewer than 5 samples are too small to stratify reliably, so
    they are kept out of the stratified split. Each of their samples is then
    sent to the test set independently with probability ``test_size``. All
    remaining samples go through ``train_test_split(..., stratify=y)``.

    Args:
        X: Array of samples indexed along axis 0 (1-D, e.g. row indices, or
            N-D feature rows).
        y: 1-D array of class labels aligned with ``X``.
        test_size: Fraction in [0, 1] of samples to put in the test set.
        random_state: Seed (int) or ``np.random.RandomState`` controlling both
            the stratified split and the small-class assignment. ``None``
            uses the global NumPy RNG for the small-class assignment.

    Returns:
        ``(X_train, X_test, y_train, y_test)``. When ``test_size == 0`` the
        inputs are returned unchanged together with empty lists for the test
        parts.
    """
    if test_size == 0:
        return X, [], y, []

    classes, counts = np.unique(y, return_counts=True)
    # Classes with fewer than 5 instances cannot be stratified reliably.
    small_classes = classes[counts < 5]

    large_class_mask = ~np.isin(y, small_classes)
    X_large, y_large = X[large_class_mask], y[large_class_mask]
    X_small, y_small = X[~large_class_mask], y[~large_class_mask]

    if len(y_large) > 0:
        X_train, X_test, y_train, y_test = train_test_split(
            X_large, y_large, test_size=test_size, random_state=random_state, stratify=y_large
        )
    else:
        # Every class is small: train_test_split rejects empty input, so start
        # from the (empty) large-class arrays and rely on the random assignment.
        X_train = X_test = X_large
        y_train = y_test = y_large

    # Honour random_state for the small-class assignment too (previously it
    # always drew from the global RNG).
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    elif random_state is None:
        rng = np.random
    else:
        rng = np.random.RandomState(random_state)
    to_test = rng.rand(len(y_small)) < test_size

    # Concatenate along axis 0 so 1-D index arrays work; the old per-item
    # np.vstack failed on them and was quadratic. Relative order of samples
    # is unchanged.
    X_train = np.concatenate([X_train, X_small[~to_test]])
    X_test = np.concatenate([X_test, X_small[to_test]])
    y_train = np.concatenate([y_train, y_small[~to_test]])
    y_test = np.concatenate([y_test, y_small[to_test]])

    return X_train, X_test, y_train, y_test

# Evaluation helpers
def accuracy_per_class(predictions, true_vals):
    """Compute accuracy separately for every distinct true label.

    Args:
        predictions: 2-D score array; each row's predicted class is its
            argmax along axis 1.
        true_vals: Array of true labels (flattened before use).

    Returns:
        ``(accuracy_dict, count_dict)`` keyed by each label found in
        ``true_vals``: the fraction of that label's samples predicted
        correctly, and how many samples carry that label.
    """
    predicted = np.argmax(predictions, axis=1).flatten()
    actual = true_vals.flatten()

    accuracy_dict, count_dict = {}, {}
    for cls in np.unique(actual):
        correct = predicted[actual == cls] == cls
        n_samples = len(correct)
        accuracy_dict[cls] = np.sum(correct) / n_samples if n_samples > 0 else 0
        count_dict[cls] = n_samples
    return accuracy_dict, count_dict


# Configure logging: INFO-and-above records from the root logger go to a file.
logging.basicConfig(
    filename='Combined_top3_20times_training_log.txt',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger()  # root logger; TQDMLoggingWrapper writes through it

class TQDMLoggingWrapper(tqdm):
    """tqdm progress bar that mirrors its progress into the module ``logger``.

    - ``update`` (tqdm calls it periodically while iterating) logs the bar's
      current description and postfix.
    - Explicit ``display(msg)`` calls log ``msg``.
    - ``set_description`` logs the new description.
    """

    def __init__(self, *args, **kwargs):
        # Bind the logger before tqdm.__init__ runs, since the base constructor
        # may already invoke overridden hooks such as display().
        self.logger = logger
        super().__init__(*args, **kwargs)

    def display(self, msg=None, pos=None):
        if msg is not None:
            self.logger.info(msg)
        return super().display(msg, pos)

    def update(self, n=1):
        refreshed = super().update(n)
        # Read the description and postfix from the bar's own attributes.
        # tqdm's format_dict stores the description under 'prefix', not
        # 'desc', so the old lookup always fell back to 'No description'.
        # getattr keeps this safe on disable=True bars, which skip setting
        # these attributes.
        desc = (getattr(self, 'desc', '') or '').rstrip(': ') or 'No description'
        postfix = getattr(self, 'postfix', None) or ''
        self.logger.info(f'{desc} - {postfix}')
        # Pass through tqdm's own return value (whether the bar was redrawn).
        return refreshed

    def set_description(self, desc=None, refresh=True):
        super().set_description(desc, refresh)
        if desc:
            self.logger.info(f'Set description: {desc}')


# Define the random seeds and other parameters.
# NOTE(review): range(2, 6, 2) yields only two seeds ([2, 4]) although the log
# and accuracy file names say "20times" — confirm the intended seed count.
seed_values = list(range(2, 6, 2))
batch_size = 8  # examples per training / validation batch
epochs = 5
learningrate = 1e-5

# Per-label accuracy history across seeds, keyed 0..len(top3label_dict)-1.
all_accuracies = {label_idx: [] for label_idx in range(len(top3label_dict))}

# Validation pass over a DataLoader using the module-level model and device.
def evaluate(dataloader_val):
    """Score a validation DataLoader with the current global ``model``.

    Relies on the module-level ``model`` and ``device`` assigned in the seed
    loop. Each batch is moved to ``device`` and passed to the model as
    ``input_ids`` / ``attention_mask`` / ``labels``, with gradients disabled.

    Returns:
        ``(mean loss per batch, stacked logits, stacked true labels)``, the
        last two as NumPy arrays concatenated along axis 0.
    """
    model.eval()
    total_loss = 0
    logit_parts = []
    label_parts = []

    for cpu_batch in dataloader_val:
        on_device = tuple(t.to(device) for t in cpu_batch)
        model_kwargs = dict(
            input_ids=on_device[0],
            attention_mask=on_device[1],
            labels=on_device[2],
        )

        with torch.no_grad():
            result = model(**model_kwargs)

        batch_loss, batch_logits = result[0], result[1]
        total_loss += batch_loss.item()

        logit_parts.append(batch_logits.detach().cpu().numpy())
        label_parts.append(model_kwargs['labels'].cpu().numpy())

    mean_loss = total_loss / len(dataloader_val)
    return (
        mean_loss,
        np.concatenate(logit_parts, axis=0),
        np.concatenate(label_parts, axis=0),
    )

# Main loop over seed values. Each seed gets its own split, a freshly
# initialised BERT classifier, fine-tuning, and a per-class evaluation.

# The tokenizer does not depend on the seed, so load it once up front.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

for seed_val in seed_values:
    # Seed Python, NumPy and torch (CPU and all GPUs) for this run.
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    # Split df1 and df2 independently, 20% of each going to validation.
    X_train_idx1, X_val_idx1, y_train1, y_val1 = custom_train_test_split(
        df1.index.values, df1.label.values, test_size=0.2, random_state=seed_val
    )
    X_train_idx2, X_val_idx2, y_train2, y_val2 = custom_train_test_split(
        df2.index.values, df2.label.values, test_size=0.2, random_state=seed_val
    )

    # Combined train / validation indices and labels from both frames.
    X_train_combined = np.concatenate((X_train_idx1, X_train_idx2))
    y_train_combined = np.concatenate((y_train1, y_train2))
    X_val_combined = np.concatenate((X_val_idx1, X_val_idx2))
    y_val_combined = np.concatenate((y_val1, y_val2))

    # Tag each row with its split. The datasets below are selected through
    # this 'data_type' column of the stacked frame.
    df1['data_type'] = 'not_set'
    df2['data_type'] = 'not_set'
    df1.loc[X_train_idx1, 'data_type'] = 'train'
    df1.loc[X_val_idx1, 'data_type'] = 'val'
    df2.loc[X_train_idx2, 'data_type'] = 'train'
    df2.loc[X_val_idx2, 'data_type'] = 'val'

    df = pd.concat([df1, df2])
    print("Combined DataFrame with 'data_type' column:\n", df)

    # Encode texts as fixed-length 256-token sequences.
    # padding='max_length' + truncation=True replace the deprecated
    # pad_to_max_length=True. With max_length set and truncation unspecified,
    # transformers already fell back to truncation, so the output is unchanged.
    encoded_data_train = tokenizer.batch_encode_plus(
        df[df.data_type == 'train'].ade.values,
        add_special_tokens=True,
        return_attention_mask=True,
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='pt',
    )
    encoded_data_val = tokenizer.batch_encode_plus(
        df[df.data_type == 'val'].ade.values,
        add_special_tokens=True,
        return_attention_mask=True,
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='pt',
    )

    input_ids_train = encoded_data_train['input_ids']
    attention_masks_train = encoded_data_train['attention_mask']
    labels_train = torch.tensor(df[df.data_type == 'train'].label.values)

    input_ids_val = encoded_data_val['input_ids']
    attention_masks_val = encoded_data_val['attention_mask']
    labels_val = torch.tensor(df[df.data_type == 'val'].label.values)

    dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
    dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

    # Fresh pretrained model per seed, so runs are independent.
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=len(top3label_dict),
        output_attentions=False,
        output_hidden_states=False,
    )

    dataloader_train = DataLoader(dataset_train, sampler=RandomSampler(dataset_train), batch_size=batch_size)
    dataloader_validation = DataLoader(dataset_val, sampler=SequentialSampler(dataset_val), batch_size=batch_size)

    optimizer = AdamW(model.parameters(), lr=learningrate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=len(dataloader_train) * epochs
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    logger.info(f"Device used: {device}")

    # Training loop
    for epoch in TQDMLoggingWrapper(range(1, epochs + 1), desc='Epoch Progress'):
        model.train()
        loss_train_total = 0

        progress_bar = TQDMLoggingWrapper(dataloader_train, desc=f'Epoch {epoch}', leave=False, disable=False)
        for batch in progress_bar:
            model.zero_grad()
            batch = tuple(b.to(device) for b in batch)
            inputs = {'input_ids': batch[0], 'attention_mask': batch[1], 'labels': batch[2]}

            outputs = model(**inputs)
            loss = outputs[0]
            loss_train_total += loss.item()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # Report the batch loss as returned by the model. The old code
            # divided by len(batch), i.e. the tuple size (3), not the number
            # of examples.
            progress_bar.set_postfix({'training_loss': f'{loss.item():.3f}'})

        # torch.save(model.state_dict(), f'./ADENorm_top3_epoch_{epoch}.model')

        logger.info(f'\nEpoch {epoch}')
        loss_train_avg = loss_train_total / len(dataloader_train)
        logger.info(f'Training loss: {loss_train_avg}')

        val_loss, predictions, true_vals = evaluate(dataloader_validation)
        val_f1 = f1_score_func(true_vals, np.argmax(predictions, axis=1), average='weighted')
        logger.info(f'Validation loss: {val_loss}')
        logger.info(f'F1 Score (Weighted): {val_f1}')

    # Final per-class evaluation for this seed.
    _, predictions, true_vals = evaluate(dataloader_validation)
    accuracy_dict, count_dict = accuracy_per_class(predictions, true_vals)

    # Only labels present in this seed's validation split get an entry appended.
    for label, accuracy in accuracy_dict.items():
        all_accuracies[label].append(accuracy)
    logger.info(f'Seed {seed_val} - Accuracy: {accuracy_dict} - Count: {count_dict}')

# Compute the per-label mean and standard deviation of accuracy across seeds.
# A label with no recorded accuracies (absent from every validation split)
# gets NaN directly. This avoids np.mean/np.std on an empty list, which emit
# "Mean of empty slice" RuntimeWarnings.
avg_accuracy = {label: np.mean(accs) if accs else float('nan') for label, accs in all_accuracies.items()}
std_accuracy = {label: np.std(accs) if accs else float('nan') for label, accs in all_accuracies.items()}


def _format_accuracy(value):
    """Format an accuracy with 4 decimals, or 'N/A' when missing or NaN."""
    # The old inline fallback applied ':.4f' to the string 'N/A', which raises
    # ValueError; handle the missing case before formatting instead.
    if value is None or np.isnan(value):
        return 'N/A'
    return f'{value:.4f}'


# Save accuracies to file
with open('Combined_top3_20times_accuracies.txt', 'w', encoding='utf-8') as f:
    # Individual accuracies.
    # NOTE(review): Seed_<k> is the k-th recorded accuracy for the label, not
    # the seed value. If a label was missing from some seed's validation split,
    # its later runs shift down by one position.
    f.write('Label\tSeed\tAccuracy\n')
    for label, accuracies in all_accuracies.items():
        for i, acc in enumerate(accuracies):
            f.write(f'{label}\tSeed_{i+1}\t{acc:.4f}\n')

    # Average and standard deviation per label.
    f.write('\nLabel\tAverage Accuracy\tStandard Deviation\n')
    for label in all_accuracies:
        avg_acc = _format_accuracy(avg_accuracy.get(label))
        std_acc = _format_accuracy(std_accuracy.get(label))
        f.write(f'{label}\t{avg_acc}\t{std_acc}\n')

# Log the final results
logger.info('All accuracies: {}'.format(all_accuracies))
logger.info('Average Accuracy: {}'.format(avg_accuracy))
logger.info('Standard Deviation of Accuracy: {}'.format(std_accuracy))

print("Average Accuracy:", avg_accuracy)
print("Standard Deviation of Accuracy:", std_accuracy)
