Skip to content

Commit

Permalink
FashionMNIST用意
Browse files Browse the repository at this point in the history
XOR学習関数微修正
  • Loading branch information
HarugumoFM committed Mar 6, 2024
1 parent e2b951c commit 4408fd2
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 78 deletions.
1 change: 1 addition & 0 deletions ConsoleML/ConsoleML.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
<ItemGroup>
<PackageReference Include="libtorch-cpu" Version="2.2.1.1" />
<PackageReference Include="TorchSharp" Version="0.102.2" />
<PackageReference Include="TorchVision" Version="0.102.2" />
</ItemGroup>

<ItemGroup>
Expand Down
193 changes: 115 additions & 78 deletions ConsoleML/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,107 +7,128 @@
using System.Globalization;
using System.Text;
using Shimotsuki.Models;
using System.Data;
using static Function;

nnTest();

// See https://aka.ms/new-console-template for more information

// Build the English and French vocabularies and collect the sentence pairs
// used for training. Reading stops once maxPairs pairs have been accepted.
var langE = new Lang();
var langF = new Lang();
var pairs = new List<string[]>();
int maxPairs = 500;
// Read English-French sentence pairs, one pair per line.
using (var reader = new StreamReader("eng-fra.txt")) {
    // ReadLine() returns null at end of file, so the local has to be nullable.
    string? line;
    int accepted = 0;
    while ((line = reader.ReadLine()) != null) {
        // NOTE(review): assumes the two sentences on a line are separated by a tab
        // (the usual eng-fra.txt layout) — confirm. Splitting on a single space can
        // never produce a pair[0] with more than 3 words, so nothing would be accepted.
        var pair = line.Split('\t');

        // Lines without a second column are ignored instead of throwing on pair[1].
        if (pair.Length >= 2) {
            int englishWords = pair[0].Split().Length;
            int frenchWords = pair[1].Split().Length;

            // Keep pairs whose first sentence has 4 to 7 words and whose second has fewer than 8.
            if (englishWords < 8 && englishWords > 3 && frenchWords < 8) {
                langE.addSentence(NormalizeString(pair[0]));
                langF.addSentence(NormalizeString(pair[1]));
                accepted++;
                pairs.Add(pair);
            }
        }
        if (maxPairs == accepted)
            break;
    }
}
Console.WriteLine("get " + pairs.Count + " pairs");
Console.WriteLine(langE.word2Index.Count + " English words");
Console.WriteLine(langF.word2Index.Count + " Frances words");

int hiddenSize = 128;

var model = new AttnSeq2Seq(langE.word2Index.Count, hiddenSize, langF.word2Index.Count);

model.LangE = langE;
model.LangF = langF;
trainFashionMNIST();

model.trainAll(pairs, 10);
//train function

var model2 = new AttnSeq2Seq(langE.word2Index.Count, hiddenSize, langF.word2Index.Count);
// Fetches the FashionMNIST datasets through torchvision, storing them under the
// current user's LocalApplicationData folder and logging before and after.
// The lines visible here only download the data; train_data and test_data are not used yet.
void trainFashionMNIST()
{
// Announce that the dataset download is starting.
Console.WriteLine("start download data");
// Root folder handed to torchvision for the dataset files.
var datasetPath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData);
// NOTE(review): the second argument (true / false) presumably selects the training
// vs. test split — confirm against the TorchVision FashionMNIST signature.
var train_data = torchvision.datasets.FashionMNIST(datasetPath, true, download: true);
var test_data = torchvision.datasets.FashionMNIST(datasetPath, false, download: true);
Console.WriteLine("success download data");

model2.load("model.bin");

model2.LangE = langE;
model2.LangF = langF;
int index = 0;
foreach(var pair in pairs) {
Console.WriteLine(string.Join(" ",pair[0]));
no_grad();
var input = tensorFromSentence(model2.LangE, NormalizeString(pair[0]));
Console.WriteLine("answer: "+ pair[1]);
Console.WriteLine("predict: " + model2.evaluate(input, 10));
index++;
if (index > 30)
break;
}

void trainSeq2Seq() {

Check warning on line 34 in ConsoleML/Program.cs

View workflow job for this annotation

GitHub Actions / build

The local function 'trainSeq2Seq' is declared but never used

Check warning on line 34 in ConsoleML/Program.cs

View workflow job for this annotation

GitHub Actions / build

The local function 'trainSeq2Seq' is declared but never used

// Vocabularies for the two languages and the list of accepted sentence pairs.
var langE = new Lang();
var langF = new Lang();
var pairs = new List<string[]>();
// Reading stops once this many pairs have been accepted.
int maxPairs = 500;
// Read English-French sentence pairs from eng-fra.txt, one pair per line.
using (var reader = new StreamReader("eng-fra.txt")) {
string line;
// Number of pairs accepted so far.
int i = 0;
// NOTE(review): ReadLine() returns null at end of file; the build warning quoted
// just below flags assigning that result to the non-nullable 'line'.
while ((line = reader.ReadLine()) != null) {

Check warning on line 44 in ConsoleML/Program.cs

View workflow job for this annotation

GitHub Actions / build

Converting null literal or possible null value to non-nullable type.

Check warning on line 44 in ConsoleML/Program.cs

View workflow job for this annotation

GitHub Actions / build

Converting null literal or possible null value to non-nullable type.
// NOTE(review): splitting on a single space leaves pair[0] as one token, so the
// "more than 3 words" test below can hardly pass — the delimiter is presumably
// meant to be a tab; confirm against the data file.
var pair = line.Split(' ');

// Accept the pair only when pair[0] has 4 to 7 whitespace-separated words and
// pair[1] has fewer than 8; accepted sentences are normalized and added to the vocabularies.
if (pair[0].Split().Length < 8 && pair[0].Split().Length > 3 && pair[1].Split().Length < 8) {
langE.addSentence(NormalizeString(pair[0]));
langF.addSentence(NormalizeString(pair[1]));
i++;
pairs.Add(pair);
}
// Stop as soon as the accepted count reaches maxPairs.
if (maxPairs == i)
break;
}
}
// Report what was loaded.
Console.WriteLine("get " + pairs.Count + " pairs");
Console.WriteLine(langE.word2Index.Count + " English words");
Console.WriteLine(langF.word2Index.Count + " Frances words");

int hiddenSize = 128;

// Model sized from the two vocabularies.
var model = new AttnSeq2Seq(langE.word2Index.Count, hiddenSize, langF.word2Index.Count);

model.LangE = langE;
model.LangF = langF;

// NOTE(review): the second argument is presumably the number of training passes — confirm in AttnSeq2Seq.
model.trainAll(pairs, 10);

// A second model with the same dimensions, restored from model.bin and used for evaluation.
var model2 = new AttnSeq2Seq(langE.word2Index.Count, hiddenSize, langF.word2Index.Count);

model2.load("model.bin");

model2.LangE = langE;
model2.LangF = langF;

// Print the expected and predicted sentence for the first 31 pairs.
int index = 0;
foreach (var pair in pairs) {
    Console.WriteLine(string.Join(" ", pair[0]));
    // no_grad() returns a disposable scope. Disposing it restores the previous gradient
    // mode; the bare call left gradient tracking switched off after this loop finished.
    using (no_grad()) {
        var input = tensorFromSentence(model2.LangE, NormalizeString(pair[0]));
        Console.WriteLine("answer: " + pair[1]);
        // NOTE(review): 10 is presumably the maximum output length — confirm in AttnSeq2Seq.evaluate.
        Console.WriteLine("predict: " + model2.evaluate(input, 10));
    }
    index++;
    if (index > 30)
        break;
}


///Function
static string UnicodeToAscii(string s) {
string normalizedString = s.Normalize(NormalizationForm.FormKD);
StringBuilder stringBuilder = new StringBuilder();

foreach (char c in normalizedString) {
UnicodeCategory unicodeCategory = CharUnicodeInfo.GetUnicodeCategory(c);
if (unicodeCategory != UnicodeCategory.NonSpacingMark) {
stringBuilder.Append(c);


///Function
// Removes combining marks from s: the text is decomposed with Unicode normalization
// form KD, every character whose category is NonSpacingMark is dropped, and the
// remaining characters are returned in their original order.
static string UnicodeToAscii(string s) {
    var kept = new StringBuilder();

    foreach (char ch in s.Normalize(NormalizationForm.FormKD)) {
        // Skip the accents that the decomposition split off their base letters.
        if (CharUnicodeInfo.GetUnicodeCategory(ch) == UnicodeCategory.NonSpacingMark) {
            continue;
        }
        kept.Append(ch);
    }

    return kept.ToString();
}

return stringBuilder.ToString();
}

// Normalizes a sentence before vocabulary lookup: lower-cases and trims it, strips
// accents via UnicodeToAscii, puts a space in front of '.', '!' and '?', and collapses
// every run of other non-letter characters into a single space.
static string NormalizeString(string s) {
    // ToLowerInvariant rather than ToLower: the culture-sensitive form lowers 'I' to a
    // dotless i under Turkish cultures, and the letter filter below would then erase it.
    s = UnicodeToAscii(s.ToLowerInvariant().Trim());
    s = Regex.Replace(s, @"([.!?])", @" $1");
    s = Regex.Replace(s, @"[^a-zA-Z.!?]+", " ");
    return s;
}

// Normalizes a sentence before vocabulary lookup: lower-cases and trims it, strips
// accents via UnicodeToAscii, puts a space in front of '.', '!' and '?', and collapses
// every run of other non-letter characters into a single space.
static string NormalizeString(string s) {
    // ToLowerInvariant rather than ToLower: the culture-sensitive form lowers 'I' to a
    // dotless i under Turkish cultures, and the letter filter below would then erase it.
    s = UnicodeToAscii(s.ToLowerInvariant().Trim());
    s = Regex.Replace(s, @"([.!?])", @" $1");
    s = Regex.Replace(s, @"[^a-zA-Z.!?]+", " ");
    return s;
}
// Converts a sentence into vocabulary indexes: the sentence is split on whitespace
// and each word is looked up in lang.word2Index, preserving word order.
// NOTE(review): a word missing from word2Index makes the lookup fail — presumably
// with KeyNotFoundException if word2Index is a Dictionary; confirm against Lang.
static List<long> indexesFromSentence(Lang lang, string sentence) {
    string[] words = sentence.Split();
    var indexes = new List<long>(words.Length);
    for (int position = 0; position < words.Length; position++) {
        indexes.Add(lang.word2Index[words[position]]);
    }
    return indexes;
}

static List<long> indexesFromSentence(Lang lang, string sentence) {
var res = new List<long>();
foreach (var word in sentence.Split()) {
res.Add(lang.word2Index[word]);
// Builds the model input for a sentence: the word indexes from indexesFromSentence
// followed by a trailing 1, turned into a tensor and viewed with shape (-1, 1),
// i.e. one index per row.
static Tensor tensorFromSentence(Lang lang, string sentence) {
    List<long> tokenIds = indexesFromSentence(lang, sentence);
    // NOTE(review): 1 is presumably the end-of-sentence index — confirm against Lang.
    tokenIds.Add(1);
    var flat = tensor(tokenIds);
    return flat.view(new long[] { -1, 1 });
}
return res;
}

// Builds the model input for a sentence: the word indexes from indexesFromSentence
// followed by a trailing 1, turned into a tensor and viewed with shape (-1, 1),
// i.e. one index per row.
static Tensor tensorFromSentence(Lang lang, string sentence) {
    List<long> tokenIds = indexesFromSentence(lang, sentence);
    // NOTE(review): 1 is presumably the end-of-sentence index — confirm against Lang.
    tokenIds.Add(1);
    var flat = tensor(tokenIds);
    return flat.view(new long[] { -1, 1 });
}



/// Model

class Net:nn.Module
Expand Down Expand Up @@ -141,15 +162,18 @@ public static void nnTest()
//テストデータ作成
var trainData = new float[,]
{
{ 0, 0 },
{ 1, 0 },
{ 0, 1 },
{ 1, 1 },
{ 0, 0 },
{ 1, 0 },
{ 0, 1 },
{ 1, 1 },
};
//ラベル
var trainLabel = new float[,]
{
{0},{1},{1},{0},
{0},
{1},
{1},
{0},
};
var model = new Net();
model.train();
Expand All @@ -169,8 +193,21 @@ public static void nnTest()
optimizer.zero_grad();
loss.backward();
optimizer.step();
if (ep % 100 == 0)
if (ep % 500 == 0)
{
Console.WriteLine($"Epoch:{ep} Loss:{loss.ToSingle()}");
model.eval();
Console.Write("x1:0,x2:0 result:");
model.forward(torch.tensor(new float[] { 0, 0 })).print();
Console.Write("x1:0,x2:1 result:");
model.forward(torch.tensor(new float[] { 0, 1 })).print();
Console.Write("x1:1,x2:0 result:");
model.forward(torch.tensor(new float[] { 1, 0 })).print();
Console.Write("x1:1,x2:1 result:");
model.forward(torch.tensor(new float[] { 1, 1 })).print();

model.train();
}
}
//評価1
model.eval();
Expand Down

0 comments on commit 4408fd2

Please sign in to comment.