In [14]:
import numpy as np
import torch
import pandas as pd


dataframe = pd.read_csv('../../Dataset/taitanic.csv')


In [15]:
# 删除没有用的列并对其余的非数值列进行one-hot编码
def ProcessDataframe(dataframe):
    dataframe = dataframe.drop(['Name', 'Ticket', 'Cabin'], axis=1)
    dataframe['Sex'] = dataframe['Sex'].map({'female': 0, 'male': 1})
    dataframe['Embarked'] = dataframe['Embarked'].map({'S': 0, 'C': 1})
    return dataframe

In [16]:
dataframe = ProcessDataframe(dataframe)

print(dataframe.head(20))

    PassengerId  Survived  Pclass  Sex   Age  SibSp  Parch     Fare  Embarked
0             1         0       3    1  22.0      1      0   7.2500       0.0
1             2         1       1    0  38.0      1      0  71.2833       1.0
2             3         1       3    0  26.0      0      0   7.9250       0.0
3             4         1       1    0  35.0      1      0  53.1000       0.0
4             5         0       3    1  35.0      0      0   8.0500       0.0
5             6         0       3    1   NaN      0      0   8.4583       NaN
6             7         0       1    1  54.0      0      0  51.8625       0.0
7             8         0       3    1   2.0      3      1  21.0750       0.0
8             9         1       3    0  27.0      0      2  11.1333       0.0
9            10         1       2    0  14.0      1      0  30.0708       1.0
10           11         1       3    0   4.0      1      1  16.7000       0.0
11           12         1       1    0  58.0      0      0  26.5

In [17]:
# 这里由于只有age和embarked有缺失值，所以只对这两列进行填补
# 故对其进行了四舍五入
# 对于其他值，需要根据实际判断是否需要四舍五入，或者采用其他取整方式

# 均值填补
def MeanFill(dataframe):
    df_filled = dataframe.copy()
    for column in df_filled.select_dtypes(include=[np.number]).columns:
        mean_value = np.round(df_filled[column].mean())
        df_filled[column].fillna(mean_value, inplace=True)
    return df_filled


# 众数填补
def ModeFill(dataframe):
    df_filled = dataframe.copy()
    for column in df_filled.select_dtypes(include=[np.object]).columns:
        mode_value = np.round(df_filled[column].mode()[0])
        df_filled[column].fillna(mode_value, inplace=True)
    return df_filled


# knn填补
def KnnImpute(df, k=3):
    df_filled = df.copy()
    for column in df.columns:
        if df[column].isnull().any():
            not_null = df[~df[column].isnull()]
            is_null = df[df[column].isnull()]
            for idx in is_null.index:
                distances = np.linalg.norm(not_null.drop(columns=[column]).values - df.loc[idx].drop(column).values, axis=1)
                nearest_indices = not_null.index[np.argsort(distances)[:k]]
                knn_value = np.round(not_null.loc[nearest_indices, column].mean())
                df_filled.at[idx, column] = knn_value
    return df_filled

In [18]:
# dataframe = MeanFill(dataframe)
# dataframe = ModeFill(dataframe)
dataframe = KnnImpute(dataframe)

print(dataframe.head(20))

    PassengerId  Survived  Pclass  Sex   Age  SibSp  Parch     Fare  Embarked
0             1         0       3    1  22.0      1      0   7.2500       0.0
1             2         1       1    0  38.0      1      0  71.2833       1.0
2             3         1       3    0  26.0      0      0   7.9250       0.0
3             4         1       1    0  35.0      1      0  53.1000       0.0
4             5         0       3    1  35.0      0      0   8.0500       0.0
5             6         0       3    1  35.0      0      0   8.4583       0.0
6             7         0       1    1  54.0      0      0  51.8625       0.0
7             8         0       3    1   2.0      3      1  21.0750       0.0
8             9         1       3    0  27.0      0      2  11.1333       0.0
9            10         1       2    0  14.0      1      0  30.0708       1.0
10           11         1       3    0   4.0      1      1  16.7000       0.0
11           12         1       1    0  58.0      0      0  26.5

In [19]:
# Z-score标准化
def z_score_standardization(df):
    return (df - df.mean()) / df.std()


# Min-Max归一化
def min_max_normalization(df):
    return (df - df.min()) / (df.max() - df.min())


# L2归一化
def l2_normalization(df):
    return df.apply(lambda x: x / np.sqrt(np.sum(x**2)), axis=1)


# Max-Abs归一化
def max_abs_normalization(df):
    return df / df.abs().max()

In [20]:
dataframe = z_score_standardization(dataframe)
# dataframe = min_max_normalization(dataframe)
# dataframe = l2_normalization(dataframe)
# dataframe = max_abs_normalization(dataframe)

print(dataframe.head(20))

    PassengerId  Survived    Pclass       Sex       Age     SibSp     Parch  \
0     -1.729137 -0.788829  0.826913  0.737281 -0.601286  0.432550 -0.473408   
1     -1.725251  1.266279 -1.565228 -1.354813  0.588901  0.432550 -0.473408   
2     -1.721365  1.266279  0.826913 -1.354813 -0.303739 -0.474279 -0.473408   
3     -1.717480  1.266279 -1.565228 -1.354813  0.365741  0.432550 -0.473408   
4     -1.713594 -0.788829  0.826913  0.737281  0.365741 -0.474279 -0.473408   
5     -1.709708 -0.788829  0.826913  0.737281  0.365741 -0.474279 -0.473408   
6     -1.705823 -0.788829 -1.565228  0.737281  1.779088 -0.474279 -0.473408   
7     -1.701937 -0.788829  0.826913  0.737281 -2.089019  2.246209  0.767199   
8     -1.698051  1.266279  0.826913 -1.354813 -0.229352 -0.474279  2.007806   
9     -1.694165  1.266279 -0.369158 -1.354813 -1.196379  0.432550 -0.473408   
10    -1.690280  1.266279  0.826913 -1.354813 -1.940246  0.432550  0.767199   
11    -1.686394  1.266279 -1.565228 -1.354813  2.076

In [21]:
data_array = dataframe.values
data_tensor = torch.tensor(data_array, dtype=torch.float32)

In [22]:
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, data_tensor):
        self.data = data_tensor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [23]:
dataset = MyDataset(data_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [24]:
for batch in dataloader:
    print("\nBatch:")
    print(batch)


Batch:
tensor([[-1.4416,  1.2663,  0.8269,  0.7373,  0.1426, -0.4743, -0.4734,  0.4888,
         -0.4853],
        [-1.5970, -0.7888, -1.5652,  0.7373, -0.1550,  0.4326, -0.4734,  1.0055,
          2.0583],
        [-1.1579, -0.7888,  0.8269, -1.3548, -1.5683,  1.3394,  2.0078,  0.0437,
         -0.4853],
        [ 0.9365, -0.7888,  0.8269,  0.7373, -1.1964,  3.1530,  0.7672,  0.1506,
         -0.4853],
        [-0.2137,  1.2663, -1.5652,  0.7373,  0.4401,  0.4326,  2.0078,  1.7667,
         -0.4853],
        [-0.8277, -0.7888, -0.3692,  0.7373,  2.1510, -0.4743, -0.4734, -0.3764,
         -0.4853],
        [-1.2085, -0.7888, -0.3692,  0.7373, -0.3781, -0.4743, -0.4734, -0.3865,
         -0.4853],
        [ 0.2370,  1.2663, -0.3692, -1.3548,  0.2170, -0.4743,  2.0078, -0.1248,
         -0.4853],
        [-0.2798, -0.7888, -1.5652,  0.7373, -0.6013, -0.4743, -0.4734,  2.0813,
          2.0583],
        [-1.1735, -0.7888,  0.8269,  0.7373, -0.8244, -0.4743, -0.4734, -0.5122,
         -0