In [1]:
import pickle

In [2]:
# Load the pickled train / test splits.  Each file holds a 10-tuple of
# per-sample sequences in the order unpacked below.
# NOTE(review): pickle.load executes arbitrary code from the file; only use
# it on files produced by this project.
# Files are opened read-only ("rb") inside `with` so the handles are closed
# as soon as the data has been loaded.
with open("/DATA/zhouyongyu/taac/train_95_7_input.txt", "rb") as fp:
    (train_creative, train_ad, train_advertiser, train_times, train_product,
     train_category, train_industry, train_mask, train_gender, train_age) = pickle.load(fp)

with open("/DATA/zhouyongyu/taac/test_95_7_input.txt", "rb") as fp:
    (test_creative, test_ad, test_advertiser, test_times, test_product,
     test_category, test_industry, test_mask, test_gender, test_age) = pickle.load(fp)

In [3]:
def _load_pickle(path):
    """Return the object pickled at ``path``, closing the file afterwards.

    NOTE(review): pickle.load executes arbitrary code from the file; only use
    it on files produced by this project.
    """
    with open(path, "rb") as fp:
        return pickle.load(fp)


# Pretrained embedding matrices, one per input sequence type.
creative_weight = _load_pickle("/DATA/zhouyongyu/taac/weight_tensor_from_w2v300_w150.txt")
ad_weight = _load_pickle("/DATA/zhouyongyu/taac/weight_tensor_ad300.txt")
advertiser_weight = _load_pickle("/DATA/zhouyongyu/taac/weight_tensor_advertiser300.txt")
times_weight = _load_pickle("/DATA/zhouyongyu/taac/weight_tensor_times300.txt")
product_weight = _load_pickle("/DATA/zhouyongyu/taac/weight_tensor_product300.txt")
category_weight = _load_pickle("/DATA/zhouyongyu/taac/weight_tensor_category300.txt")
industry_weight = _load_pickle("/DATA/zhouyongyu/taac/weight_tensor_industry300.txt")

In [4]:
import torch
# Wrap every pretrained weight matrix in an nn.Embedding lookup table.
# freeze=True marks the weight as requires_grad=False, so the pretrained
# vectors are not updated by the optimizer.
creative_embedding = torch.nn.Embedding.from_pretrained(creative_weight, freeze=True)
ad_embedding = torch.nn.Embedding.from_pretrained(ad_weight, freeze=True)
advertiser_embedding = torch.nn.Embedding.from_pretrained(advertiser_weight, freeze=True)
times_embedding = torch.nn.Embedding.from_pretrained(times_weight, freeze=True)
product_embedding = torch.nn.Embedding.from_pretrained(product_weight, freeze=True)
category_embedding = torch.nn.Embedding.from_pretrained(category_weight, freeze=True)
industry_embedding = torch.nn.Embedding.from_pretrained(industry_weight, freeze=True)


In [5]:
def generate_batch(creative, ad, advertiser, times, product, category, industry, mask, age, gender, batch_size):
    """
    Split the sample indices into consecutive mini-batches.

    Parameters
    ----------
    creative, ad, advertiser, times, product, category, industry, mask, age, gender :
        Per-sample sequences.  Only ``len(creative)`` is used to determine
        the number of samples; all ten are converted with ``np.array``.
    batch_size : int
        Maximum number of samples per batch; must be positive.

    Returns
    -------
    tuple
        ``(slices, creative, ad, advertiser, times, product, category,
        industry, mask, age, gender)`` where ``slices`` is a list of index
        arrays such as ``[[0, 1, 2], [3, 4, 5], ...]`` (the last one may be
        shorter) and the remaining ten items are the inputs as ndarrays.
        ``slices`` is an empty list when there are no samples.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    length = len(creative)
    # One index array per batch; building each range directly avoids the
    # split-then-trim dance and works for an empty dataset as well.
    slices = [np.arange(start, min(start + batch_size, length))
              for start in range(0, length, batch_size)]

    return (slices, np.array(creative), np.array(ad), np.array(advertiser), np.array(times),
            np.array(product), np.array(category), np.array(industry), np.array(mask),
            np.array(age), np.array(gender))

def get_slice(creative, ad, advertiser, times, product, category, industry, mask, age, gender, batch_index):
    """
    Select one mini-batch from each of the ten per-sample arrays.

    ``batch_index`` is one entry of the ``slices`` list (an array holding the
    indices of the samples that make up the batch).  The result is a 10-tuple
    in the same order as the arguments, each item indexed by ``batch_index``.
    """
    columns = (creative, ad, advertiser, times, product, category, industry, mask, age, gender)
    return tuple(column[batch_index] for column in columns)

In [6]:
import torch.nn as nn
import torch

def trans_to_device(variable):
    """Return ``variable.cuda()`` when CUDA is available, otherwise ``variable`` unchanged."""
    return variable.cuda() if torch.cuda.is_available() else variable

class BiLSTM(nn.Module):
    """
    Age classifier over seven parallel ID sequences.

    Each input sequence (creative, ad, advertiser, times, product, category,
    industry) is looked up in its own embedding table, encoded by its own
    (optionally bidirectional) LSTM and max-pooled over time.  The seven
    pooled vectors are concatenated and passed through a
    BatchNorm / Linear / ReLU / Dropout head whose output is
    ``tanh(BatchNorm(Linear(...)))`` with ``output_size_age`` columns.
    """

    # Number of parallel input sequences, each with its own embedding + LSTM.
    N_STREAMS = 7

    def __init__(self, output_size_age, output_size_gender, embedding_dim, hidden_dim, n_layers,
                 bidirectional=True, drop_prob=0.5, embeddings=None):
        """
        Initialize the model by setting up the layers.

        Parameters
        ----------
        output_size_age : int
            Number of output columns (age classes).
        output_size_gender : int
            Stored on the instance; no gender head is built.
        embedding_dim : int
            Input feature size of every LSTM (must match the embeddings).
        hidden_dim : int
            Hidden size of every LSTM.
        n_layers : int
            Number of stacked layers in every LSTM.
        bidirectional : bool
            Whether the LSTMs are bidirectional.
        drop_prob : float
            Dropout probability between stacked LSTM layers.
        embeddings : sequence of 7 nn.Module, optional
            Embedding modules in the order creative, ad, advertiser, times,
            product, category, industry.  When omitted, the module-level
            ``creative_embedding`` ... ``industry_embedding`` objects are used.

        Raises
        ------
        ValueError
            If ``embeddings`` does not contain exactly seven modules.
        """
        super(BiLSTM, self).__init__()

        self.output_size_age = output_size_age
        self.output_size_gender = output_size_gender
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional

        # One LSTM per input sequence, registered as lstm1 .. lstm7.
        self.lstm1 = self._make_lstm(embedding_dim, hidden_dim, n_layers, drop_prob, bidirectional)

        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.dropout3 = nn.Dropout(0.5)

        for index in range(2, self.N_STREAMS + 1):
            setattr(self, "lstm%d" % index,
                    self._make_lstm(embedding_dim, hidden_dim, n_layers, drop_prob, bidirectional))

        # Classification head.  The width of the concatenated pooled features
        # depends on the number of directions, so it is derived here instead
        # of being hard-coded per branch; the same head is then built for
        # both the unidirectional and the bidirectional case.
        num_directions = 2 if bidirectional else 1
        feature_dim = hidden_dim * self.N_STREAMS * num_directions

        self.bn0 = nn.BatchNorm1d(feature_dim, momentum=0.5)
        self.fc1 = nn.Linear(feature_dim, 400)

        self.bn1 = nn.BatchNorm1d(400, momentum=0.5)
        self.fc2 = nn.Linear(400, 200)

        self.bn2 = nn.BatchNorm1d(200, momentum=0.5)
        self.fc_age = nn.Linear(200, output_size_age)
        self.bn_age = nn.BatchNorm1d(output_size_age, momentum=0.5)

        self.sig = nn.Sigmoid()

        # Re-initialise the trainable layers BEFORE the embeddings are
        # attached, so the embedding weights are left untouched.
        self.reset_parameters()

        if embeddings is None:
            # Backward-compatible default: the embeddings built at module level.
            embeddings = (creative_embedding, ad_embedding, advertiser_embedding, times_embedding,
                          product_embedding, category_embedding, industry_embedding)
        embeddings = tuple(embeddings)
        if len(embeddings) != self.N_STREAMS:
            raise ValueError("expected %d embedding modules, got %d" % (self.N_STREAMS, len(embeddings)))

        (self.creative_embedding, self.ad_embedding, self.advertiser_embedding, self.times_embedding,
         self.product_embedding, self.category_embedding, self.industry_embedding) = embeddings

    @staticmethod
    def _make_lstm(embedding_dim, hidden_dim, n_layers, drop_prob, bidirectional):
        """Build one batch-first LSTM encoder with the shared hyper-parameters."""
        return nn.LSTM(embedding_dim, hidden_dim, n_layers,
                       dropout=drop_prob, batch_first=True,
                       bidirectional=bidirectional)

    def _encode_stream(self, lstm, embedded, lengths):
        """Run ``lstm`` over the valid prefix of each sequence and max-pool over time.

        ``embedded`` is batch-first; ``lengths`` is a CPU integer tensor with
        the number of valid (non-padded) steps of every sample.
        """
        num_directions = 2 if self.bidirectional else 1
        state_shape = (self.n_layers * num_directions, embedded.shape[0], self.hidden_dim)

        # Random initial states are kept for training; in eval mode zeros are
        # used so that inference is deterministic.  The states are created on
        # the input's own device so the module works wherever it is placed.
        make_state = torch.randn if self.training else torch.zeros
        h0 = make_state(state_shape, device=embedded.device, dtype=embedded.dtype)
        c0 = make_state(state_shape, device=embedded.device, dtype=embedded.dtype)

        # enforce_sorted=False lets PyTorch sort by length internally and
        # restore the original sample order when unpacking.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = lstm(packed, (h0, c0))

        # Pad with -inf (not 0) so that padded steps can never win the max;
        # otherwise a sample's pooled features would depend on the longest
        # sequence that happens to share its batch.
        padded, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, padding_value=float("-inf"))
        return torch.max(padded, 1)[0]

    def forward(self, creative, ad, advertiser, times, product, category, industry, mask):
        """
        Compute the age scores for one batch.

        The seven ID tensors are index tensors for the corresponding
        embeddings; ``mask`` is summed along dim 1 to obtain the number of
        valid steps per sample (every sample needs at least one).
        Returns a tensor with ``output_size_age`` columns, squashed by tanh.
        """
        lengths = torch.sum(mask, dim=1).long().cpu()

        streams = (
            (self.lstm1, self.creative_embedding, creative),
            (self.lstm2, self.ad_embedding, ad),
            (self.lstm3, self.advertiser_embedding, advertiser),
            (self.lstm4, self.times_embedding, times),
            (self.lstm5, self.product_embedding, product),
            (self.lstm6, self.category_embedding, category),
            (self.lstm7, self.industry_embedding, industry),
        )
        pooled = [self._encode_stream(lstm, embedding(ids), lengths)
                  for lstm, embedding, ids in streams]

        out = torch.cat(pooled, 1)

        # dropout and fully-connected layers
        out = self.bn0(self.dropout1(out))
        out = self.dropout2(torch.relu(self.bn1(self.fc1(out))))
        out = self.dropout3(torch.relu(self.bn2(self.fc2(out))))

        return torch.tanh(self.bn_age(self.fc_age(out)))

    def reset_parameters(self):
        """Fill every registered parameter with U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
        import math
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

In [None]:
import numpy as np
from tqdm import tqdm

def _batch_to_tensors(arrays, batch_index):
    """Build the device tensors for one mini-batch.

    ``arrays`` holds the ten per-sample arrays in ``get_slice`` order.
    Returns ``(inputs, age)`` where ``inputs`` is the list of the seven ID
    tensors followed by the mask (the positional arguments of the model's
    forward) and ``age`` is the label tensor.
    """
    creative, ad, advertiser, times, product, category, industry, mask, age, _gender = \
        get_slice(*arrays, batch_index=batch_index)

    # Convert straight to int64: going through float32 first
    # (torch.Tensor(x).long()) would round integer ids above 2**24.
    inputs = [trans_to_device(torch.as_tensor(ids, dtype=torch.long))
              for ids in (creative, ad, advertiser, times, product, category, industry)]
    inputs.append(trans_to_device(torch.as_tensor(mask, dtype=torch.float)))
    age_tensor = trans_to_device(torch.as_tensor(age, dtype=torch.long))
    return inputs, age_tensor


def train(n_epochs=50, batch_size=512, save_dir='/DATA/zhouyongyu/model_save/w2v32_age_model/'):
    """
    Train the age model and keep the checkpoint with the best test accuracy.

    Every epoch runs one pass over the training split, then evaluates on the
    test split.  Whenever the accuracy is at least as good as the best so
    far, the whole model is saved into ``save_dir`` and the previously saved
    checkpoint is deleted.

    Parameters
    ----------
    n_epochs : int
        Number of training epochs.
    batch_size : int
        Mini-batch size for both training and evaluation.
    save_dir : str
        Directory that receives the checkpoint files.
    """
    import os
    # Restricting the visible GPUs only has an effect if it happens before
    # CUDA is initialised in this process, so do it first.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"

    torch.cuda.empty_cache()

    output_size_age = 10
    output_size_gender = 2
    embedding_dim = 300
    hidden_dim = 100
    n_layers = 2
    bidirectional = False  # False: unidirectional LSTMs; set to True for bidirectional ones

    model = trans_to_device(BiLSTM(output_size_age, output_size_gender, embedding_dim, hidden_dim, n_layers, bidirectional))
    print(model)

    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # The batch partition and the ndarray conversion are identical in every
    # epoch, so they are computed once up front.
    # NOTE(review): batches are visited in the same order every epoch (no
    # shuffling) — confirm this is intended.
    train_slices, *train_arrays = generate_batch(
        train_creative, train_ad, train_advertiser, train_times, train_product, train_category,
        train_industry, train_mask, train_age, train_gender, batch_size=batch_size)
    test_slices, *test_arrays = generate_batch(
        test_creative, test_ad, test_advertiser, test_times, test_product, test_category,
        test_industry, test_mask, test_age, test_gender, batch_size=batch_size)
    n_test = len(test_arrays[0])

    best_age_acc = float('-inf')  # guarantees that the first epoch is saved
    last_model_name = ''

    for epoch in tqdm(range(n_epochs)):
        model.train()
        total_loss = 0.0
        for batch_index in train_slices:
            optimizer.zero_grad()

            inputs, age_labels = _batch_to_tensors(train_arrays, batch_index)
            out_age = model(*inputs)

            # Labels are shifted by one so that they start at 0.
            loss = loss_function(out_age, age_labels - 1)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print('\tLoss:\t%.3f' % total_loss)

        print("begin model.eval")
        model.eval()

        age_correct = 0
        # No gradients are needed for evaluation; skipping the autograd
        # graph saves memory and time.
        with torch.no_grad():
            for batch_index in test_slices:
                inputs, age_labels = _batch_to_tensors(test_arrays, batch_index)
                pred_age = model(*inputs).argmax(1)
                age_correct += (pred_age == age_labels - 1).sum().item()

        age_acc = float(age_correct) / n_test if n_test else 0.0
        print('\tage_acc:\t%.5f' % age_acc)

        if age_acc >= best_age_acc:
            best_age_acc = age_acc
            previous_model_name = last_model_name
            last_model_name = '1_copy_age_initmodel_7input_epoch_%s.pkl_%.5f' % (str(epoch), age_acc)
            print("save model")
            torch.save(model, os.path.join(save_dir, last_model_name))
            # Remove the superseded checkpoint only after the new one is on disk.
            if previous_model_name:
                os.remove(os.path.join(save_dir, previous_model_name))

# Run the training / evaluation loop defined by train().
train()

In [4]:
import numpy as np
# Convert the training mask to an ndarray so it can be reduced along an axis.
np_mask = np.array(train_mask)

In [5]:
# Sum of every mask row (axis=1); presumably the number of valid steps per
# sample — TODO confirm the mask is 0/1.
print(np.sum(np_mask, axis=1))

[10  9 14 ... 13 89 12]


In [9]:
# Smallest per-row mask sum over the whole training mask.
print(np.min(np.sum(np_mask, axis=1)))

3


In [42]:
# Scratch tensor of shape (3, 4) with standard-normal entries.
a = torch.randn(3,4)

In [43]:
# Show the whole scratch tensor.
print(a)

tensor([[ 1.0113,  0.0845,  0.8785, -0.6062],
        [-0.7693, -1.1818,  0.1705, -1.2648],
        [-0.2886, -1.6722, -0.0205,  0.7305]])


In [44]:
# Indexing with a single integer selects the first row.
print(a[0])

tensor([ 1.0113,  0.0845,  0.8785, -0.6062])


In [64]:
# Small integer matrix used to check argmax along dim 1.
a=torch.tensor([[2, 3, 1],[2, 3, 4]])

In [65]:
# argmax(1) gives, for every row, the column index of its largest value.
print(a.argmax(1))

tensor([1, 2])


In [15]:
# torch.max(a, 1) returns a (values, indices) pair; [0] keeps the values,
# whose size no longer contains the reduced dimension.
a =torch.Tensor([[4,1,2],[3,4,5]])
torch.max(a,1)[0].size()

torch.Size([2])