In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader,TensorDataset
import re

torch.manual_seed(1)

<torch._C.Generator at 0x7fbd7be79db0>

In [83]:
with open('data/mindata.txt','r') as f:
    txt = "".join(f.readlines()).split('\n')
txt = txt[:-1]

In [90]:
def find_pos(text):
    st = re.search(r"\(",text).span()[0]+1
    end = re.search(r"\)",text).span()[0]
    endpoint = re.search(r"\.",text).span()[0]
    return st,end,endpoint
labels = []
sentences = []
for doc in txt:
    st,end,endpoint = find_pos(doc)
    sentences.append(doc[:endpoint])
    if doc[st:end]=='NS': #NS:1 PS:0
        labels.append(1)
    else:
        labels.append(0)
print(labels)
print(sentences)

[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
['Peter is running', 'He is not walking', 'We should tell the truth', 'We should never tell lies', 'Everyone is in the garden', 'There is no one in the house', 'The fridge is empty', 'There is nothing in it', 'It is very cloudy', 'It isn’t sunny', 'I have sold the last newspaper', 'I have no newspapers left', 'Someone has eaten all the cookies', 'There are none in the bag']


In [101]:
allwords = set(" ".join(sentences).split())
voc_size = len(allwords)+1

In [103]:
word_to_index = {word:index+1 for index,word in enumerate(allwords)}
word_to_index["<UNK>"] = 0
index_to_word = {index:word for word,index in word_to_index.items()}


In [110]:
sentence_to_index = []
for sentence in sentences:
    sentence_to_index.append([word_to_index[word] for word in sentence.split()])
    
max_len = max([len(sentence) for sentence in sentence_to_index])
max_len

7

In [113]:
def padding_sentence(sentence,padding_size):
    cur_len = len(sentence)
    return [0]*(padding_size-cur_len)+sentence

for i in range(len(sentences)):
    sentence_to_index[i] = padding_sentence(sentence_to_index[i],max_len)

sentence_to_index

[[0, 0, 0, 0, 4, 17, 6],
 [0, 0, 0, 21, 17, 32, 38],
 [0, 0, 3, 7, 19, 14, 20],
 [0, 0, 3, 7, 9, 19, 37],
 [0, 0, 10, 17, 45, 14, 13],
 [18, 17, 24, 36, 45, 14, 33],
 [0, 0, 0, 26, 15, 17, 16],
 [0, 0, 18, 17, 43, 45, 12],
 [0, 0, 0, 34, 17, 28, 23],
 [0, 0, 0, 0, 34, 5, 41],
 [0, 8, 22, 1, 14, 11, 29],
 [0, 0, 8, 22, 24, 40, 30],
 [0, 42, 27, 39, 35, 14, 25],
 [0, 18, 31, 44, 45, 14, 2]]

In [114]:
assert len(sentence_to_index)==len(labels)

In [115]:
sentence_to_index = torch.LongTensor(sentence_to_index)
labels = torch.LongTensor(labels)
datasets = TensorDataset(sentence_to_index,labels)
dataloader = DataLoader(datasets,batch_size=128,shuffle=True)


In [230]:
class RNN(nn.Module):
    def __init__(self,vocab_size,emb_size,hidden_size,n_layers,out_size):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.emb = nn.Embedding(vocab_size,emb_size,padding_idx=0)
        self.rnn = nn.RNN(emb_size,hidden_size,num_layers=n_layers,batch_first=True)
        self.fc = nn.Linear(hidden_size,out_size)
    
    def forward(self,x):
        emb = self.emb(x) #batch_size*sequence_size*embed_size
        rnn_out,h = self.rnn(emb) #batch_size*sequence_size*hidden_size
        h = h.squeeze()
        output = torch.sigmoid(self.fc(h))
    
        return output
        

In [231]:
emb_size = 8
hidden_size = 16
n_layers = 1
learning_rate = 0.01
out_size = 1
epochs = 2000

model = RNN(voc_size,emb_size,hidden_size,n_layers,out_size)
critirion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(),lr = learning_rate)


In [232]:
for name,para in model.named_parameters():
    print(name,para.shape)

emb.weight torch.Size([46, 8])
rnn.weight_ih_l0 torch.Size([16, 8])
rnn.weight_hh_l0 torch.Size([16, 16])
rnn.bias_ih_l0 torch.Size([16])
rnn.bias_hh_l0 torch.Size([16])
fc.weight torch.Size([1, 16])
fc.bias torch.Size([1])


In [233]:
for epoch in range(epochs):
    for x,y in dataloader:
        pre = model(x)
        loss = critirion(pre.view(-1),y.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print("Epoch: {}/{} Loss:{} ".format(epoch,epochs,loss))       

Epoch: 0/2000 Loss:0.6896926760673523 
Epoch: 1/2000 Loss:0.6408464312553406 
Epoch: 2/2000 Loss:0.5962398648262024 
Epoch: 3/2000 Loss:0.5535113215446472 
Epoch: 4/2000 Loss:0.5116739869117737 
Epoch: 5/2000 Loss:0.4701729714870453 
Epoch: 6/2000 Loss:0.4287409484386444 
Epoch: 7/2000 Loss:0.38727131485939026 
Epoch: 8/2000 Loss:0.3455752432346344 
Epoch: 9/2000 Loss:0.30345383286476135 
Epoch: 10/2000 Loss:0.26141461730003357 
Epoch: 11/2000 Loss:0.22101879119873047 
Epoch: 12/2000 Loss:0.18408986926078796 
Epoch: 13/2000 Loss:0.15162762999534607 
Epoch: 14/2000 Loss:0.12378551065921783 
Epoch: 15/2000 Loss:0.10045023262500763 
Epoch: 16/2000 Loss:0.081368587911129 
Epoch: 17/2000 Loss:0.06609045714139938 
Epoch: 18/2000 Loss:0.05402151122689247 
Epoch: 19/2000 Loss:0.04453780874609947 
Epoch: 20/2000 Loss:0.037074487656354904 
Epoch: 21/2000 Loss:0.03116586245596409 
Epoch: 22/2000 Loss:0.026448650285601616 
Epoch: 23/2000 Loss:0.022647710517048836 
Epoch: 24/2000 Loss:0.01955727301

Epoch: 199/2000 Loss:0.000625335262157023 
Epoch: 200/2000 Loss:0.0006215053726918995 
Epoch: 201/2000 Loss:0.0006177223986014724 
Epoch: 202/2000 Loss:0.000613943615462631 
Epoch: 203/2000 Loss:0.0006102289189584553 
Epoch: 204/2000 Loss:0.0006065694615244865 
Epoch: 205/2000 Loss:0.0006029057549312711 
Epoch: 206/2000 Loss:0.0005992846563458443 
Epoch: 207/2000 Loss:0.000595719029661268 
Epoch: 208/2000 Loss:0.0005921703996136785 
Epoch: 209/2000 Loss:0.0005886601866222918 
Epoch: 210/2000 Loss:0.0005851839669048786 
Epoch: 211/2000 Loss:0.000581737607717514 
Epoch: 212/2000 Loss:0.0005783254164271057 
Epoch: 213/2000 Loss:0.000574955774936825 
Epoch: 214/2000 Loss:0.0005716202431358397 
Epoch: 215/2000 Loss:0.0005683018243871629 
Epoch: 216/2000 Loss:0.0005650216480717063 
Epoch: 217/2000 Loss:0.0005617713322862983 
Epoch: 218/2000 Loss:0.0005585466860793531 
Epoch: 219/2000 Loss:0.0005553431692533195 
Epoch: 220/2000 Loss:0.0005521823186427355 
Epoch: 221/2000 Loss:0.00054905132856

Epoch: 412/2000 Loss:0.00023900081578176469 
Epoch: 413/2000 Loss:0.00023819170019123703 
Epoch: 414/2000 Loss:0.00023738261370453984 
Epoch: 415/2000 Loss:0.00023656497069168836 
Epoch: 416/2000 Loss:0.00023576438252348453 
Epoch: 417/2000 Loss:0.00023496376525145024 
Epoch: 418/2000 Loss:0.00023417166084982455 
Epoch: 419/2000 Loss:0.00023337535094469786 
Epoch: 420/2000 Loss:0.00023259605222847313 
Epoch: 421/2000 Loss:0.00023182100267149508 
Epoch: 422/2000 Loss:0.000231058758799918 
Epoch: 423/2000 Loss:0.00023027091810945421 
Epoch: 424/2000 Loss:0.0002295214362675324 
Epoch: 425/2000 Loss:0.0002287293755216524 
Epoch: 426/2000 Loss:0.0002279883628943935 
Epoch: 427/2000 Loss:0.00022723039728589356 
Epoch: 428/2000 Loss:0.00022647665173280984 
Epoch: 429/2000 Loss:0.00022573141905013472 
Epoch: 430/2000 Loss:0.0002249819372082129 
Epoch: 431/2000 Loss:0.00022423671907745302 
Epoch: 432/2000 Loss:0.00022351274674292654 
Epoch: 433/2000 Loss:0.00022277179232332855 
Epoch: 434/2000 

Epoch: 601/2000 Loss:0.000138275368954055 
Epoch: 602/2000 Loss:0.00013794323604088277 
Epoch: 603/2000 Loss:0.0001375812862534076 
Epoch: 604/2000 Loss:0.00013725768076255918 
Epoch: 605/2000 Loss:0.00013691278581973165 
Epoch: 606/2000 Loss:0.00013659341493621469 
Epoch: 607/2000 Loss:0.00013624427083414048 
Epoch: 608/2000 Loss:0.00013591638708021492 
Epoch: 609/2000 Loss:0.0001355842687189579 
Epoch: 610/2000 Loss:0.00013524787209462374 
Epoch: 611/2000 Loss:0.0001349200028926134 
Epoch: 612/2000 Loss:0.0001345751079497859 
Epoch: 613/2000 Loss:0.00013426427904050797 
Epoch: 614/2000 Loss:0.00013392788241617382 
Epoch: 615/2000 Loss:0.00013360001321416348 
Epoch: 616/2000 Loss:0.00013329343346413225 
Epoch: 617/2000 Loss:0.00013295705139171332 
Epoch: 618/2000 Loss:0.0001326419587712735 
Epoch: 619/2000 Loss:0.00013231835328042507 
Epoch: 620/2000 Loss:0.00013200750981923193 
Epoch: 621/2000 Loss:0.0001316881534876302 
Epoch: 622/2000 Loss:0.00013136028428561985 
Epoch: 623/2000 Lo

Epoch: 797/2000 Loss:9.00198720046319e-05 
Epoch: 798/2000 Loss:8.984528540167958e-05 
Epoch: 799/2000 Loss:8.966644963948056e-05 
Epoch: 800/2000 Loss:8.950890332926065e-05 
Epoch: 801/2000 Loss:8.931305637815967e-05 
Epoch: 802/2000 Loss:8.914699719753116e-05 
Epoch: 803/2000 Loss:8.898094529286027e-05 
Epoch: 804/2000 Loss:8.880211680661887e-05 
Epoch: 805/2000 Loss:8.862327376846224e-05 
Epoch: 806/2000 Loss:8.846574928611517e-05 
Epoch: 807/2000 Loss:8.829543367028236e-05 
Epoch: 808/2000 Loss:8.812510350253433e-05 
Epoch: 809/2000 Loss:8.796756446827203e-05 
Epoch: 810/2000 Loss:8.777597395237535e-05 
Epoch: 811/2000 Loss:8.760990749578923e-05 
Epoch: 812/2000 Loss:8.745236846152693e-05 
Epoch: 813/2000 Loss:8.729483670322224e-05 
Epoch: 814/2000 Loss:8.710323163541034e-05 
Epoch: 815/2000 Loss:8.694142888998613e-05 
Epoch: 816/2000 Loss:8.679665916133672e-05 
Epoch: 817/2000 Loss:8.662635082146153e-05 
Epoch: 818/2000 Loss:8.644326590001583e-05 
Epoch: 819/2000 Loss:8.6281463154

Epoch: 995/2000 Loss:6.342135748127475e-05 
Epoch: 996/2000 Loss:6.331066106213257e-05 
Epoch: 997/2000 Loss:6.321274122456089e-05 
Epoch: 998/2000 Loss:6.311479955911636e-05 
Epoch: 999/2000 Loss:6.301262328634039e-05 
Epoch: 1000/2000 Loss:6.290617602644488e-05 
Epoch: 1001/2000 Loss:6.280399247771129e-05 
Epoch: 1002/2000 Loss:6.269754521781579e-05 
Epoch: 1003/2000 Loss:6.25953616690822e-05 
Epoch: 1004/2000 Loss:6.24974345555529e-05 
Epoch: 1005/2000 Loss:6.24080203124322e-05 
Epoch: 1006/2000 Loss:6.230583676369861e-05 
Epoch: 1007/2000 Loss:6.21993895038031e-05 
Epoch: 1008/2000 Loss:6.210572610143572e-05 
Epoch: 1009/2000 Loss:6.200779171194881e-05 
Epoch: 1010/2000 Loss:6.190134445205331e-05 
Epoch: 1011/2000 Loss:6.180342461448163e-05 
Epoch: 1012/2000 Loss:6.172252324176952e-05 
Epoch: 1013/2000 Loss:6.162458885228261e-05 
Epoch: 1014/2000 Loss:6.150537228677422e-05 
Epoch: 1015/2000 Loss:6.141170160844922e-05 
Epoch: 1016/2000 Loss:6.133081478765234e-05 
Epoch: 1017/2000 Lo

Epoch: 1185/2000 Loss:4.768485450767912e-05 
Epoch: 1186/2000 Loss:4.762524986290373e-05 
Epoch: 1187/2000 Loss:4.7548612201353535e-05 
Epoch: 1188/2000 Loss:4.7501780500169843e-05 
Epoch: 1189/2000 Loss:4.741662633023225e-05 
Epoch: 1190/2000 Loss:4.7365530917886645e-05 
Epoch: 1191/2000 Loss:4.728889325633645e-05 
Epoch: 1192/2000 Loss:4.722928497358225e-05 
Epoch: 1193/2000 Loss:4.715690738521516e-05 
Epoch: 1194/2000 Loss:4.710581197286956e-05 
Epoch: 1195/2000 Loss:4.702491787611507e-05 
Epoch: 1196/2000 Loss:4.696956602856517e-05 
Epoch: 1197/2000 Loss:4.6888675569789484e-05 
Epoch: 1198/2000 Loss:4.684609666583128e-05 
Epoch: 1199/2000 Loss:4.6765198931097984e-05 
Epoch: 1200/2000 Loss:4.671410351875238e-05 
Epoch: 1201/2000 Loss:4.66332130599767e-05 
Epoch: 1202/2000 Loss:4.6582117647631094e-05 
Epoch: 1203/2000 Loss:4.650973278330639e-05 
Epoch: 1204/2000 Loss:4.645864828489721e-05 
Epoch: 1205/2000 Loss:4.638626705855131e-05 
Epoch: 1206/2000 Loss:4.632666241377592e-05 
Epoch

Epoch: 1371/2000 Loss:3.736004146048799e-05 
Epoch: 1372/2000 Loss:3.728765659616329e-05 
Epoch: 1373/2000 Loss:3.7249341403367e-05 
Epoch: 1374/2000 Loss:3.71982496290002e-05 
Epoch: 1375/2000 Loss:3.71514142898377e-05 
Epoch: 1376/2000 Loss:3.710458258865401e-05 
Epoch: 1377/2000 Loss:3.7066263757878914e-05 
Epoch: 1378/2000 Loss:3.702794128912501e-05 
Epoch: 1379/2000 Loss:3.697685679071583e-05 
Epoch: 1380/2000 Loss:3.692575774039142e-05 
Epoch: 1381/2000 Loss:3.688318611239083e-05 
Epoch: 1382/2000 Loss:3.683635077322833e-05 
Epoch: 1383/2000 Loss:3.6789515434065834e-05 
Epoch: 1384/2000 Loss:3.6751200241269544e-05 
Epoch: 1385/2000 Loss:3.670436126412824e-05 
Epoch: 1386/2000 Loss:3.664901669253595e-05 
Epoch: 1387/2000 Loss:3.661495429696515e-05 
Epoch: 1388/2000 Loss:3.656811895780265e-05 
Epoch: 1389/2000 Loss:3.6517030821414664e-05 
Epoch: 1390/2000 Loss:3.646593540906906e-05 
Epoch: 1391/2000 Loss:3.6436133086681366e-05 
Epoch: 1392/2000 Loss:3.638504495029338e-05 
Epoch: 13

Epoch: 1559/2000 Loss:2.9900695153628476e-05 
Epoch: 1560/2000 Loss:2.9879405701649375e-05 
Epoch: 1561/2000 Loss:2.9858116249670275e-05 
Epoch: 1562/2000 Loss:2.9832570362486877e-05 
Epoch: 1563/2000 Loss:2.9764449209324084e-05 
Epoch: 1564/2000 Loss:2.9743159757344984e-05 
Epoch: 1565/2000 Loss:2.972187394334469e-05 
Epoch: 1566/2000 Loss:2.9696328056161292e-05 
Epoch: 1567/2000 Loss:2.9640981665579602e-05 
Epoch: 1568/2000 Loss:2.9611177524202503e-05 
Epoch: 1569/2000 Loss:2.95941445074277e-05 
Epoch: 1570/2000 Loss:2.9551569241448306e-05 
Epoch: 1571/2000 Loss:2.9504737540264614e-05 
Epoch: 1572/2000 Loss:2.9487706342479214e-05 
Epoch: 1573/2000 Loss:2.9449385692714714e-05 
Epoch: 1574/2000 Loss:2.941532693512272e-05 
Epoch: 1575/2000 Loss:2.9372751669143327e-05 
Epoch: 1576/2000 Loss:2.9355720471357927e-05 
Epoch: 1577/2000 Loss:2.932591814897023e-05 
Epoch: 1578/2000 Loss:2.9283341064001434e-05 
Epoch: 1579/2000 Loss:2.9253538741613738e-05 
Epoch: 1580/2000 Loss:2.922799103544093

Epoch: 1753/2000 Loss:2.4301960365846753e-05 
Epoch: 1754/2000 Loss:2.4267899789265357e-05 
Epoch: 1755/2000 Loss:2.425938509986736e-05 
Epoch: 1756/2000 Loss:2.4221068088081665e-05 
Epoch: 1757/2000 Loss:2.420403870928567e-05 
Epoch: 1758/2000 Loss:2.416146162431687e-05 
Epoch: 1759/2000 Loss:2.414868686173577e-05 
Epoch: 1760/2000 Loss:2.4123144612531178e-05 
Epoch: 1761/2000 Loss:2.4110373487928882e-05 
Epoch: 1762/2000 Loss:2.4080563889583573e-05 
Epoch: 1763/2000 Loss:2.4042250515776686e-05 
Epoch: 1764/2000 Loss:2.4029475753195584e-05 
Epoch: 1765/2000 Loss:2.3995415176614188e-05 
Epoch: 1766/2000 Loss:2.3991156922420487e-05 
Epoch: 1767/2000 Loss:2.3948583475430496e-05 
Epoch: 1768/2000 Loss:2.3918779334053397e-05 
Epoch: 1769/2000 Loss:2.39017499552574e-05 
Epoch: 1770/2000 Loss:2.3876205887063406e-05 
Epoch: 1771/2000 Loss:2.3859174689278007e-05 
Epoch: 1772/2000 Loss:2.3829370547900908e-05 
Epoch: 1773/2000 Loss:2.3799568225513212e-05 
Epoch: 1774/2000 Loss:2.378253520873841e

Epoch: 1941/2000 Loss:2.0176375983282924e-05 
Epoch: 1942/2000 Loss:2.0155088350293227e-05 
Epoch: 1943/2000 Loss:2.0146571841905825e-05 
Epoch: 1944/2000 Loss:2.0129546101088636e-05 
Epoch: 1945/2000 Loss:2.011676951951813e-05 
Epoch: 1946/2000 Loss:2.0082710761926137e-05 
Epoch: 1947/2000 Loss:2.005716487474274e-05 
Epoch: 1948/2000 Loss:2.0044390112161636e-05 
Epoch: 1949/2000 Loss:2.0027362552355044e-05 
Epoch: 1950/2000 Loss:2.0010333173559047e-05 
Epoch: 1951/2000 Loss:1.9997558410977945e-05 
Epoch: 1952/2000 Loss:1.9980527213192545e-05 
Epoch: 1953/2000 Loss:1.9946468455600552e-05 
Epoch: 1954/2000 Loss:1.993369369301945e-05 
Epoch: 1955/2000 Loss:1.9920922568417154e-05 
Epoch: 1956/2000 Loss:1.9891118427040055e-05 
Epoch: 1957/2000 Loss:1.9878349121427163e-05 
Epoch: 1958/2000 Loss:1.986557435884606e-05 
Epoch: 1959/2000 Loss:1.9844288544845767e-05 
Epoch: 1960/2000 Loss:1.981448440346867e-05 
Epoch: 1961/2000 Loss:1.9810224330285564e-05 
Epoch: 1962/2000 Loss:1.977616557269357

In [234]:
y_score = model(x).detach().numpy().reshape(-1)

In [235]:
from sklearn.metrics import roc_auc_score

roc_auc_score(labels.numpy(),y_score)

0.40816326530612246

In [124]:
class Model(nn.Module):
    def __init__(self,voc_size,emb_size):
        super().__init__()
        self.emb = nn.Embedding(voc_size,emb_size,padding_idx=0)
        self.fc = nn.Linear(emb_size,voc_size)
    
    def forward(self,x):jioayujioa
        emb = self.emb(x)
        out = self.fc(emb)
        return out

model = Model(10,3)
input = torch.LongTensor([1,2,3,0,0,0,1,2,3])
y = torch.LongTensor([0,0,1,2,1,2,1,2,1])
optimizer = optim.Adam(model.parameters(),lr = 0.1)
critirion = nn.CrossEntropyLoss()


In [130]:
next(model.parameters()).data

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-2.5598e+00, -1.4235e+00,  5.1625e-02],
        [ 1.5118e-03, -9.5410e-01,  1.6400e+00],
        [-1.3023e-01,  4.7651e-01,  1.4844e+00],
        [-8.4063e-01, -1.6716e-02,  1.6845e+00],
        [ 7.9070e-01,  1.0001e-01, -6.6812e-01],
        [-7.9196e-01, -1.9943e+00, -2.6602e+00],
        [-1.3921e+00, -5.0291e-02,  8.4315e-01],
        [ 1.7655e+00,  2.4189e+00, -9.6869e-01],
        [-2.8445e-01,  1.7152e-01, -1.8190e+00]])

In [46]:
for name,para in model.named_parameters():
    print(para)

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.1631,  0.6614,  1.1899],
        [ 0.8165, -0.9135, -0.3538],
        [ 0.7639, -0.5890, -0.7636],
        [ 1.3352,  0.6043,  1.3275],
        [-0.4954,  1.5496,  0.3476],
        [ 0.0930,  0.6147,  0.7124],
        [-1.7765,  3.3212, -0.4021],
        [-0.7123, -0.6200, -0.2281],
        [-0.7893, -1.6111, -1.8716]], requires_grad=True)
Parameter containing:
tensor([[-0.3948, -0.4848, -0.2646],
        [-0.0672, -0.3539,  0.2112],
        [ 0.1787, -0.1307,  0.2219],
        [ 0.1866,  0.3525,  0.3888],
        [-0.1955,  0.5641, -0.0667],
        [-0.0198, -0.5449, -0.3716],
        [-0.3373, -0.2469,  0.4105],
        [-0.1887, -0.4314,  0.2221],
        [ 0.1848,  0.3739, -0.2988],
        [ 0.1252, -0.2102, -0.1297]], requires_grad=True)
Parameter containing:
tensor([-0.4601, -0.2631, -0.1768,  0.2469,  0.1055,  0.1426,  0.5763,  0.5627,
         0.3938,  0.0184], requires_grad=True)


In [47]:
for epoch in range(10):
    pre = model(input)
    loss = critirion(pre,y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [49]:
for para in model.parameters():
    print(para)

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.2238, -0.3439,  0.9977],
        [ 1.1004, -1.8155, -0.8346],
        [ 1.4325, -1.5533, -0.0548],
        [ 1.3352,  0.6043,  1.3275],
        [-0.4954,  1.5496,  0.3476],
        [ 0.0930,  0.6147,  0.7124],
        [-1.7765,  3.3212, -0.4021],
        [-0.7123, -0.6200, -0.2281],
        [-0.7893, -1.6111, -1.8716]], requires_grad=True)
Parameter containing:
tensor([[ 0.3356, -1.4843,  0.5877],
        [ 0.8566, -1.3078,  0.3889],
        [ 0.5409, -0.6523, -0.6968],
        [-0.6412,  0.9468, -0.4744],
        [-0.9818,  1.2005, -0.9720],
        [-0.8097,  0.3021,  0.1341],
        [-1.1153,  0.6748, -0.5426],
        [-0.9828,  0.4701, -0.5090],
        [-0.5445,  1.2259, -0.7624],
        [-0.6299,  0.6319, -0.0675]], requires_grad=True)
Parameter containing:
tensor([ 0.4416,  0.6831,  0.7453, -0.6819, -0.8263, -0.7828, -0.3597, -0.3727,
        -0.5342, -0.9156], requires_grad=True)
