In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from sklearn.model_selection import KFold

Reference (k-fold cross-validation with PyTorch): https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-use-k-fold-cross-validation-with-pytorch.md

In [2]:
# Dataset directories, given as "./"-relative paths with a trailing slash so
# that a file name can be appended by plain string concatenation.
XENON_DATASET_PATH = "./data/xenon/"
RAMAN_DATASET_PATH = "./data/raman_diabetis_spectroscopy/"

In [3]:
# Load one Raman dataset and build the label vector `y` and the mean-centred
# feature matrix `X`.
dset_name = 'vein'
d = pd.read_csv(RAMAN_DATASET_PATH + dset_name + '.csv')

# Discard the first data row (presumably a non-sample row -- TODO confirm
# against the CSV layout).
d = d.iloc[1:, :]

# Column 1 is used as the integer class label; columns 800..1799 (1000 columns)
# are used as the features.
y = d.iloc[:, 1].astype(int)
X = d.iloc[:, 800:1800]

# Centre every column on its own mean. Subtracting the Series of column means
# broadcasts it across all rows (pandas aligns the Series index with X's
# columns), so this works for any number of rows. The previous version stacked
# 20 copies of the means with the private DataFrame._append and relied on row
# index alignment, which silently produced NaN rows unless the file had exactly
# 21 rows.
# NOTE(review): `means` is now a Series of per-column means rather than a
# 20-row DataFrame.
means = X.mean(0)
X = X - means

In [4]:
# Boolean row masks for the two label values found in y.
neg_mask = y == 0
pos_mask = y == 1

# Features and labels of the rows labelled 0.
X_neg = X.loc[neg_mask]
y_neg = y.loc[neg_mask]

# Features and labels of the rows labelled 1.
X_pos = X.loc[pos_mask]
y_pos = y.loc[pos_mask]


In [5]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_features: Length of each input vector (default 1000).
        hidden_features: Width of the hidden layer (default 14).
        n_classes: Number of output values per sample (default 2).
    """

    def __init__(self, in_features=1000, hidden_features=14, n_classes=2):
        # `super().__init__()` (with the call parentheses on `super`) must run
        # before any sub-module is assigned; the earlier `super.__init__()`
        # raised TypeError, so the model could never be constructed.
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, n_classes),
        )

    def forward(self, x):
        """Pass `x` through the layer stack and return the final linear output."""
        return self.layers(x)

In [24]:
# Print every pairing of a fold over X_pos with a fold over X_neg.
n = 5

# Two separate shuffled splitters, one per class. No random_state is set, so
# the folds differ from run to run.
kf1 = KFold(n_splits=n, shuffle=True)
kf2 = KFold(n_splits=n, shuffle=True)

# Offset added to the X_neg indices before printing
# (presumably so they address rows placed after the X_pos rows in a combined
# array -- TODO confirm against the code that consumes these indices).
c = len(X_pos)

for train_pos, test_pos in kf1.split(X_pos):
    # kf2.split is called afresh for each outer fold, so the X_neg folds are
    # drawn again every time rather than reused.
    for train_neg, test_neg in kf2.split(X_neg):
        shifted_train_neg = train_neg + c
        shifted_test_neg = test_neg + c
        print(train_pos, test_pos, shifted_train_neg, shifted_test_neg)

[ 1  2  4  5  6  7  9 10] [0 3 8] [11 12 13 15 16 17 19] [14 18]
[ 1  2  4  5  6  7  9 10] [0 3 8] [11 12 13 14 16 17 18] [15 19]
[ 1  2  4  5  6  7  9 10] [0 3 8] [12 13 14 15 17 18 19] [11 16]
[ 1  2  4  5  6  7  9 10] [0 3 8] [11 14 15 16 17 18 19] [12 13]
[ 1  2  4  5  6  7  9 10] [0 3 8] [11 12 13 14 15 16 18 19] [17]
[ 0  1  2  3  5  7  8  9 10] [4 6] [11 12 14 15 16 17 18] [13 19]
[ 0  1  2  3  5  7  8  9 10] [4 6] [11 12 13 14 16 18 19] [15 17]
[ 0  1  2  3  5  7  8  9 10] [4 6] [11 13 14 15 16 17 19] [12 18]
[ 0  1  2  3  5  7  8  9 10] [4 6] [11 12 13 15 17 18 19] [14 16]
[ 0  1  2  3  5  7  8  9 10] [4 6] [12 13 14 15 16 17 18 19] [11]
[ 0  2  3  4  5  6  7  8 10] [1 9] [11 12 13 14 15 16 19] [17 18]
[ 0  2  3  4  5  6  7  8 10] [1 9] [11 12 13 14 16 17 18] [15 19]
[ 0  2  3  4  5  6  7  8 10] [1 9] [12 14 15 16 17 18 19] [11 13]
[ 0  2  3  4  5  6  7  8 10] [1 9] [11 13 14 15 17 18 19] [12 16]
[ 0  2  3  4  5  6  7  8 10] [1 9] [11 12 13 15 16 17 18 19] [14]
[ 0  1  3  4  5